Skip to main content

uni_query/query/df_graph/
locy_fixpoint.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Fixpoint iteration operator for recursive Locy strata.
5//!
6//! `FixpointExec` drives semi-naive evaluation: it repeatedly evaluates the rules
7//! in a recursive stratum, feeding back deltas until no new facts are produced.
8
9use crate::query::df_graph::GraphExecutionContext;
10use crate::query::df_graph::common::{
11    ScalarKey, arrow_err, collect_all_partitions, compute_plan_properties, execute_subplan,
12    extract_scalar_key,
13};
14use crate::query::df_graph::locy_best_by::{BestByExec, SortCriterion};
15use crate::query::df_graph::locy_errors::LocyRuntimeError;
16use crate::query::df_graph::locy_explain::{
17    ProofTerm, ProvenanceAnnotation, ProvenanceStore, compute_proof_probability,
18};
19use crate::query::df_graph::locy_fold::{FoldBinding, FoldExec};
20use crate::query::df_graph::locy_priority::PriorityExec;
21use crate::query::planner::LogicalPlan;
22use arrow_array::RecordBatch;
23use arrow_row::{RowConverter, SortField};
24use arrow_schema::SchemaRef;
25use datafusion::common::JoinType;
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
29use datafusion::physical_plan::memory::MemoryStream;
30use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
31use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
32use futures::Stream;
33use parking_lot::RwLock;
34use std::any::Any;
35use std::collections::{HashMap, HashSet};
36use std::fmt;
37use std::pin::Pin;
38use std::sync::{Arc, RwLock as StdRwLock};
39use std::task::{Context, Poll};
40use std::time::{Duration, Instant};
41use uni_common::Value;
42use uni_common::core::schema::Schema as UniSchema;
43use uni_locy::RuntimeWarning;
44use uni_store::storage::manager::StorageManager;
45
46// ---------------------------------------------------------------------------
47// DerivedScanRegistry — injection point for IS-ref data into subplans
48// ---------------------------------------------------------------------------
49
/// A single entry in the derived scan registry.
///
/// Each entry corresponds to one `LocyDerivedScan` node in the logical plan tree.
/// The `data` handle is shared with the logical plan node so that writing data here
/// makes it visible when the subplan is re-planned and executed.
#[derive(Debug)]
pub struct DerivedScanEntry {
    /// Index matching the `scan_index` in `LocyDerivedScan`.
    pub scan_index: usize,
    /// Name of the rule this scan reads from.
    pub rule_name: String,
    /// Whether this is a self-referential scan (rule references itself).
    pub is_self_ref: bool,
    /// Shared data handle — write batches here to inject into subplans.
    ///
    /// Guarded by `parking_lot::RwLock` (the unqualified `RwLock` import in this
    /// module — `std::sync::RwLock` is aliased as `StdRwLock`), so lock
    /// acquisition is infallible and cannot poison.
    pub data: Arc<RwLock<Vec<RecordBatch>>>,
    /// Schema of the derived relation.
    pub schema: SchemaRef,
}
68
/// Registry of derived scan handles for fixpoint iteration.
///
/// During fixpoint, each clause body may reference derived relations via
/// `LocyDerivedScan` nodes. The registry maps scan indices to shared data
/// handles so the fixpoint loop can inject delta/full facts before each
/// iteration.
#[derive(Debug, Default)]
pub struct DerivedScanRegistry {
    /// Registered entries; lookups are a linear scan on `scan_index`
    /// (registries are expected to be small).
    entries: Vec<DerivedScanEntry>,
}
79
80impl DerivedScanRegistry {
81    /// Create a new empty registry.
82    pub fn new() -> Self {
83        Self::default()
84    }
85
86    /// Add an entry to the registry.
87    pub fn add(&mut self, entry: DerivedScanEntry) {
88        self.entries.push(entry);
89    }
90
91    /// Get an entry by scan index.
92    pub fn get(&self, scan_index: usize) -> Option<&DerivedScanEntry> {
93        self.entries.iter().find(|e| e.scan_index == scan_index)
94    }
95
96    /// Write data into a scan entry's shared handle.
97    pub fn write_data(&self, scan_index: usize, batches: Vec<RecordBatch>) {
98        if let Some(entry) = self.get(scan_index) {
99            let mut guard = entry.data.write();
100            *guard = batches;
101        }
102    }
103
104    /// Get all entries for a given rule name.
105    pub fn entries_for_rule(&self, rule_name: &str) -> Vec<&DerivedScanEntry> {
106        self.entries
107            .iter()
108            .filter(|e| e.rule_name == rule_name)
109            .collect()
110    }
111}
112
113// ---------------------------------------------------------------------------
114// MonotonicAggState — tracking monotonic aggregates across iterations
115// ---------------------------------------------------------------------------
116
/// Monotonic aggregate binding: maps a fold name to its aggregate kind and column.
#[derive(Debug, Clone)]
pub struct MonotonicFoldBinding {
    /// Name of the fold; used as part of the accumulator map key.
    pub fold_name: String,
    /// Aggregate kind (Sum/Count/Max/Min/Nor/Prod) deciding how values combine.
    pub kind: crate::query::df_graph::locy_fold::FoldAggKind,
    /// Index of the input column whose values feed this accumulator.
    pub input_col_index: usize,
}
124
/// Tracks monotonic aggregate accumulators across fixpoint iterations.
///
/// After each iteration, accumulators are updated and compared to their previous
/// snapshot. The fixpoint has converged (w.r.t. aggregates) when all accumulators
/// are stable (no change between iterations).
#[derive(Debug)]
pub struct MonotonicAggState {
    /// Current accumulator values keyed by (group_key, fold_name).
    accumulators: HashMap<(Vec<ScalarKey>, String), f64>,
    /// Snapshot from the previous iteration for stability check.
    prev_snapshot: HashMap<(Vec<ScalarKey>, String), f64>,
    /// Bindings describing which aggregates to track — one accumulator per
    /// (group key, binding) pair.
    bindings: Vec<MonotonicFoldBinding>,
}
139
140impl MonotonicAggState {
141    /// Create a new monotonic aggregate state.
142    pub fn new(bindings: Vec<MonotonicFoldBinding>) -> Self {
143        Self {
144            accumulators: HashMap::new(),
145            prev_snapshot: HashMap::new(),
146            bindings,
147        }
148    }
149
150    /// Update accumulators with new delta batches.
151    ///
152    /// Returns `true` if any accumulator value changed. When `strict` is
153    /// `true`, Nor/Prod inputs outside `[0, 1]` produce an error instead
154    /// of being clamped.
155    pub fn update(
156        &mut self,
157        key_indices: &[usize],
158        delta_batches: &[RecordBatch],
159        strict: bool,
160    ) -> DFResult<bool> {
161        use crate::query::df_graph::locy_fold::FoldAggKind;
162
163        let mut changed = false;
164        for batch in delta_batches {
165            for row_idx in 0..batch.num_rows() {
166                let group_key = extract_scalar_key(batch, key_indices, row_idx);
167                for binding in &self.bindings {
168                    let col = batch.column(binding.input_col_index);
169                    let val = extract_f64(col.as_ref(), row_idx);
170                    if let Some(val) = val {
171                        let map_key = (group_key.clone(), binding.fold_name.clone());
172                        let entry = self
173                            .accumulators
174                            .entry(map_key)
175                            .or_insert(binding.kind.identity().unwrap_or(0.0));
176                        let old = *entry;
177                        match binding.kind {
178                            FoldAggKind::Sum | FoldAggKind::Count => *entry += val,
179                            FoldAggKind::Max => {
180                                if val > *entry {
181                                    *entry = val;
182                                }
183                            }
184                            FoldAggKind::Min => {
185                                if val < *entry {
186                                    *entry = val;
187                                }
188                            }
189                            FoldAggKind::Nor => {
190                                if strict && !(0.0..=1.0).contains(&val) {
191                                    return Err(datafusion::error::DataFusionError::Execution(
192                                        format!(
193                                            "strict_probability_domain: MNOR input {val} is outside [0, 1]"
194                                        ),
195                                    ));
196                                }
197                                if !strict && !(0.0..=1.0).contains(&val) {
198                                    tracing::warn!(
199                                        "MNOR input {val} outside [0,1], clamped to {}",
200                                        val.clamp(0.0, 1.0)
201                                    );
202                                }
203                                let p = val.clamp(0.0, 1.0);
204                                *entry = 1.0 - (1.0 - *entry) * (1.0 - p);
205                            }
206                            FoldAggKind::Prod => {
207                                if strict && !(0.0..=1.0).contains(&val) {
208                                    return Err(datafusion::error::DataFusionError::Execution(
209                                        format!(
210                                            "strict_probability_domain: MPROD input {val} is outside [0, 1]"
211                                        ),
212                                    ));
213                                }
214                                if !strict && !(0.0..=1.0).contains(&val) {
215                                    tracing::warn!(
216                                        "MPROD input {val} outside [0,1], clamped to {}",
217                                        val.clamp(0.0, 1.0)
218                                    );
219                                }
220                                let p = val.clamp(0.0, 1.0);
221                                *entry *= p;
222                            }
223                            _ => {}
224                        }
225                        if (*entry - old).abs() > f64::EPSILON {
226                            changed = true;
227                        }
228                    }
229                }
230            }
231        }
232        Ok(changed)
233    }
234
235    /// Take a snapshot of current accumulators for stability comparison.
236    pub fn snapshot(&mut self) {
237        self.prev_snapshot = self.accumulators.clone();
238    }
239
240    /// Check if accumulators are stable (no change since last snapshot).
241    pub fn is_stable(&self) -> bool {
242        if self.accumulators.len() != self.prev_snapshot.len() {
243            return false;
244        }
245        for (key, val) in &self.accumulators {
246            match self.prev_snapshot.get(key) {
247                Some(prev) if (*val - *prev).abs() <= f64::EPSILON => {}
248                _ => return false,
249            }
250        }
251        true
252    }
253
254    /// Test-only accessor for accumulator values.
255    #[cfg(test)]
256    pub(crate) fn get_accumulator(&self, key: &(Vec<ScalarKey>, String)) -> Option<f64> {
257        self.accumulators.get(key).copied()
258    }
259}
260
261/// Extract f64 value from an Arrow column at a given row index.
262fn extract_f64(col: &dyn arrow_array::Array, row_idx: usize) -> Option<f64> {
263    if col.is_null(row_idx) {
264        return None;
265    }
266    if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Float64Array>() {
267        Some(arr.value(row_idx))
268    } else {
269        col.as_any()
270            .downcast_ref::<arrow_array::Int64Array>()
271            .map(|arr| arr.value(row_idx) as f64)
272    }
273}
274
275// ---------------------------------------------------------------------------
276// RowDedupState — Arrow RowConverter-based persistent dedup set
277// ---------------------------------------------------------------------------
278
/// Arrow-native row deduplication using [`RowConverter`].
///
/// Unlike the legacy `HashSet<Vec<ScalarKey>>` approach, this struct maintains a
/// persistent `seen` set across iterations so per-iteration cost is O(M) where M
/// is the number of candidate rows — the full facts table is never re-scanned.
struct RowDedupState {
    /// Converts Arrow columns into the normalized byte-row encoding used as the
    /// dedup key.
    converter: RowConverter,
    /// Persistent set of row encodings accepted so far, across all iterations.
    seen: HashSet<Box<[u8]>>,
}
288
289impl RowDedupState {
290    /// Try to build a `RowDedupState` for the given schema.
291    ///
292    /// Returns `None` if any column type is not supported by `RowConverter`
293    /// (triggers legacy fallback).
294    fn try_new(schema: &SchemaRef) -> Option<Self> {
295        let fields: Vec<SortField> = schema
296            .fields()
297            .iter()
298            .map(|f| SortField::new(f.data_type().clone()))
299            .collect();
300        match RowConverter::new(fields) {
301            Ok(converter) => Some(Self {
302                converter,
303                seen: HashSet::new(),
304            }),
305            Err(e) => {
306                tracing::warn!(
307                    "RowDedupState: RowConverter unsupported for schema, falling back to legacy dedup: {}",
308                    e
309                );
310                None
311            }
312        }
313    }
314
315    /// Populate the seen set from existing fact batches.
316    ///
317    /// Used after BEST BY in-loop pruning replaces the fact set, so that delta
318    /// computation in subsequent iterations correctly recognizes surviving facts.
319    fn ingest_existing(&mut self, facts: &[RecordBatch], _schema: &SchemaRef) {
320        self.seen.clear();
321        for batch in facts {
322            if batch.num_rows() == 0 {
323                continue;
324            }
325            let arrays: Vec<_> = batch.columns().to_vec();
326            if let Ok(rows) = self.converter.convert_columns(&arrays) {
327                for row_idx in 0..batch.num_rows() {
328                    let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
329                    self.seen.insert(row_bytes);
330                }
331            }
332        }
333    }
334
335    /// Filter `candidates` to only rows not yet seen, updating the persistent set.
336    ///
337    /// Both cross-iteration dedup (rows already accepted in prior iterations) and
338    /// within-batch dedup (duplicate rows in a single candidate batch) are handled
339    /// in a single pass.
340    fn compute_delta(
341        &mut self,
342        candidates: &[RecordBatch],
343        schema: &SchemaRef,
344    ) -> DFResult<Vec<RecordBatch>> {
345        let mut delta_batches = Vec::new();
346        for batch in candidates {
347            if batch.num_rows() == 0 {
348                continue;
349            }
350
351            // Vectorized encoding of all rows in this batch.
352            let arrays: Vec<_> = batch.columns().to_vec();
353            let rows = self.converter.convert_columns(&arrays).map_err(arrow_err)?;
354
355            // One pass: check+insert into persistent seen set.
356            let mut keep = Vec::with_capacity(batch.num_rows());
357            for row_idx in 0..batch.num_rows() {
358                let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
359                keep.push(self.seen.insert(row_bytes));
360            }
361
362            let keep_mask = arrow_array::BooleanArray::from(keep);
363            let new_cols = batch
364                .columns()
365                .iter()
366                .map(|col| {
367                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
368                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
369                    })
370                })
371                .collect::<DFResult<Vec<_>>>()?;
372
373            if new_cols.first().is_some_and(|c| !c.is_empty()) {
374                let filtered = RecordBatch::try_new(Arc::clone(schema), new_cols).map_err(|e| {
375                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
376                })?;
377                delta_batches.push(filtered);
378            }
379        }
380        Ok(delta_batches)
381    }
382}
383
384// ---------------------------------------------------------------------------
385// FixpointState — per-rule delta tracking during fixpoint iteration
386// ---------------------------------------------------------------------------
387
/// Per-rule state for fixpoint iteration.
///
/// Tracks accumulated facts and the delta (new facts from the latest iteration).
/// Deduplication uses Arrow [`RowConverter`] with a persistent seen set (O(M) per
/// iteration) when supported, with a legacy `HashSet<Vec<ScalarKey>>` fallback.
pub struct FixpointState {
    /// Name of the rule this state belongs to (used in errors and logs).
    rule_name: String,
    /// All facts accepted so far, accumulated across iterations.
    facts: Vec<RecordBatch>,
    /// Facts newly produced by the most recent iteration only.
    delta: Vec<RecordBatch>,
    /// Output schema; may be reconciled against actual plan output at runtime.
    schema: SchemaRef,
    /// Column indices forming the rule's KEY (grouping) columns.
    key_column_indices: Vec<usize>,
    /// All column indices for full-row dedup (legacy path only).
    all_column_indices: Vec<usize>,
    /// Running total of facts bytes for memory limit tracking.
    facts_bytes: usize,
    /// Maximum bytes allowed for this derived relation.
    max_derived_bytes: usize,
    /// Optional monotonic aggregate tracking.
    monotonic_agg: Option<MonotonicAggState>,
    /// Arrow RowConverter-based dedup state; `None` triggers legacy fallback.
    row_dedup: Option<RowDedupState>,
    /// Whether strict probability domain checks are enabled.
    strict_probability_domain: bool,
}
412
impl FixpointState {
    /// Create a new fixpoint state for a rule.
    ///
    /// `row_dedup` is `None` (legacy fallback) when the schema's column types
    /// are unsupported by Arrow's `RowConverter`.
    pub fn new(
        rule_name: String,
        schema: SchemaRef,
        key_column_indices: Vec<usize>,
        max_derived_bytes: usize,
        monotonic_agg: Option<MonotonicAggState>,
        strict_probability_domain: bool,
    ) -> Self {
        let num_cols = schema.fields().len();
        let row_dedup = RowDedupState::try_new(&schema);
        Self {
            rule_name,
            facts: Vec::new(),
            delta: Vec::new(),
            schema,
            key_column_indices,
            // Legacy dedup keys over the full row, i.e. every column.
            all_column_indices: (0..num_cols).collect(),
            facts_bytes: 0,
            max_derived_bytes,
            monotonic_agg,
            row_dedup,
            strict_probability_domain,
        }
    }

    /// Reconcile the pre-computed schema with the actual physical plan output.
    ///
    /// `infer_expr_type` may guess wrong (e.g. `Property → Float64` for a
    /// string column).  When the first real batch arrives with a different
    /// schema, update ours so that `RowDedupState` / `RecordBatch::try_new`
    /// use the correct types.
    ///
    /// NOTE(review): rebuilding `row_dedup` discards the persistent seen set;
    /// this assumes the schema only ever changes before any facts have been
    /// accepted — confirm against the callers.
    fn reconcile_schema(&mut self, actual_schema: &SchemaRef) {
        if self.schema.fields() != actual_schema.fields() {
            tracing::debug!(
                rule = %self.rule_name,
                "Reconciling fixpoint schema from physical plan output",
            );
            self.schema = Arc::clone(actual_schema);
            self.row_dedup = RowDedupState::try_new(&self.schema);
        }
    }

    /// Merge candidate rows into facts, computing delta (truly new rows).
    ///
    /// Returns `true` if any new facts were added.
    pub async fn merge_delta(
        &mut self,
        candidates: Vec<RecordBatch>,
        task_ctx: Option<Arc<TaskContext>>,
    ) -> DFResult<bool> {
        // No candidate rows at all: clear the delta so is_converged() sees an
        // empty delta for this iteration.
        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
            self.delta.clear();
            return Ok(false);
        }

        // Reconcile schema from the first non-empty candidate batch.
        // The physical plan's output types are authoritative over the
        // planner's inferred types.
        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
            self.reconcile_schema(&first.schema());
        }

        // Round floats for stable dedup
        let candidates = round_float_columns(&candidates);

        // Compute delta: rows in candidates not already in facts
        let delta = self.compute_delta(&candidates, task_ctx.as_ref()).await?;

        if delta.is_empty() || delta.iter().all(|b| b.num_rows() == 0) {
            self.delta.clear();
            // Update monotonic aggs even with empty delta (for stability check)
            // Snapshot with no subsequent update makes accumulators equal to
            // the snapshot, so is_stable() reports stable.
            if let Some(ref mut agg) = self.monotonic_agg {
                agg.snapshot();
            }
            return Ok(false);
        }

        // Check memory limit
        let delta_bytes: usize = delta.iter().map(batch_byte_size).sum();
        if self.facts_bytes + delta_bytes > self.max_derived_bytes {
            return Err(datafusion::error::DataFusionError::Execution(
                LocyRuntimeError::MemoryLimitExceeded {
                    rule: self.rule_name.clone(),
                    bytes: self.facts_bytes + delta_bytes,
                    limit: self.max_derived_bytes,
                }
                .to_string(),
            ));
        }

        // Update monotonic aggs
        // Snapshot first, then fold the delta in, so is_stable() compares the
        // post-update values against this iteration's starting point.
        if let Some(ref mut agg) = self.monotonic_agg {
            agg.snapshot();
            agg.update(
                &self.key_column_indices,
                &delta,
                self.strict_probability_domain,
            )?;
        }

        // Append delta to facts
        self.facts_bytes += delta_bytes;
        self.facts.extend(delta.iter().cloned());
        self.delta = delta;

        Ok(true)
    }

    /// Dispatch to vectorized LeftAntiJoin, Arrow RowConverter dedup, or legacy ScalarKey dedup.
    ///
    /// Priority order:
    /// 1. `arrow_left_anti_dedup` when `total_existing >= DEDUP_ANTI_JOIN_THRESHOLD` and task_ctx available.
    /// 2. `RowDedupState` (persistent HashSet, O(M) per iteration) when schema is supported.
    /// 3. `compute_delta_legacy` (rebuilds from facts, fallback for unsupported column types).
    async fn compute_delta(
        &mut self,
        candidates: &[RecordBatch],
        task_ctx: Option<&Arc<TaskContext>>,
    ) -> DFResult<Vec<RecordBatch>> {
        let total_existing: usize = self.facts.iter().map(|b| b.num_rows()).sum();
        if total_existing >= DEDUP_ANTI_JOIN_THRESHOLD
            && let Some(ctx) = task_ctx
        {
            return arrow_left_anti_dedup(candidates.to_vec(), &self.facts, &self.schema, ctx)
                .await;
        }
        if let Some(ref mut rd) = self.row_dedup {
            rd.compute_delta(candidates, &self.schema)
        } else {
            self.compute_delta_legacy(candidates)
        }
    }

    /// Legacy dedup: rebuild a `HashSet<Vec<ScalarKey>>` from all facts each call.
    ///
    /// Used as fallback when `RowConverter` does not support the schema's column types.
    fn compute_delta_legacy(&self, candidates: &[RecordBatch]) -> DFResult<Vec<RecordBatch>> {
        // Build set of existing fact row keys (ALL columns)
        let mut existing: HashSet<Vec<ScalarKey>> = HashSet::new();
        for batch in &self.facts {
            for row_idx in 0..batch.num_rows() {
                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
                existing.insert(key);
            }
        }

        let mut delta_batches = Vec::new();
        for batch in candidates {
            if batch.num_rows() == 0 {
                continue;
            }
            // Filter to only new rows
            // First pass: mark rows not present in the prior facts.
            let mut keep = Vec::with_capacity(batch.num_rows());
            for row_idx in 0..batch.num_rows() {
                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
                keep.push(!existing.contains(&key));
            }

            // Also dedup within the candidate batch itself
            // Second pass: inserting into `existing` drops later duplicates of
            // a row that survived the first pass.
            for (row_idx, kept) in keep.iter_mut().enumerate() {
                if *kept {
                    let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
                    if !existing.insert(key) {
                        *kept = false;
                    }
                }
            }

            let keep_mask = arrow_array::BooleanArray::from(keep);
            let new_rows = batch
                .columns()
                .iter()
                .map(|col| {
                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
                    })
                })
                .collect::<DFResult<Vec<_>>>()?;

            // Only emit batches that still contain rows after filtering.
            if new_rows.first().is_some_and(|c| !c.is_empty()) {
                let filtered =
                    RecordBatch::try_new(Arc::clone(&self.schema), new_rows).map_err(|e| {
                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
                    })?;
                delta_batches.push(filtered);
            }
        }

        Ok(delta_batches)
    }

    /// Check if this rule has converged (no new facts and aggs stable).
    pub fn is_converged(&self) -> bool {
        let delta_empty = self.delta.is_empty() || self.delta.iter().all(|b| b.num_rows() == 0);
        let agg_stable = self.monotonic_agg.as_ref().is_none_or(|a| a.is_stable());
        delta_empty && agg_stable
    }

    /// Get all accumulated facts.
    pub fn all_facts(&self) -> &[RecordBatch] {
        &self.facts
    }

    /// Get the delta from the latest iteration.
    pub fn all_delta(&self) -> &[RecordBatch] {
        &self.delta
    }

    /// Consume self and return facts.
    pub fn into_facts(self) -> Vec<RecordBatch> {
        self.facts
    }

    /// Merge candidates using BEST BY semantics.
    ///
    /// Combines existing facts with new candidates, keeping only the best row
    /// per KEY group according to `sort_criteria`. Returns `true` if the
    /// best-per-KEY fact set actually changed (a genuinely better value was
    /// found or a new KEY appeared).
    ///
    /// This replaces `merge_delta` for rules with BEST BY, enabling convergence
    /// on cyclic graphs where dominated ALONG values would otherwise produce an
    /// unbounded stream of "new" full-row facts.
    pub fn merge_best_by(
        &mut self,
        candidates: Vec<RecordBatch>,
        sort_criteria: &[SortCriterion],
    ) -> DFResult<bool> {
        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
            self.delta.clear();
            return Ok(false);
        }

        // Reconcile schema from the first non-empty candidate batch.
        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
            self.reconcile_schema(&first.schema());
        }

        // Round floats for stable dedup.
        // (Also makes float criteria comparable via ScalarKey in the change
        // detection below.)
        let candidates = round_float_columns(&candidates);

        // Snapshot existing best-per-KEY facts for change detection.
        let old_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> =
            self.build_key_criteria_map(sort_criteria);

        // Concat existing facts + new candidates.
        let mut all_batches = self.facts.clone();
        all_batches.extend(candidates);
        let all_batches: Vec<_> = all_batches
            .into_iter()
            .filter(|b| b.num_rows() > 0)
            .collect();
        if all_batches.is_empty() {
            self.delta.clear();
            return Ok(false);
        }

        let combined = arrow::compute::concat_batches(&self.schema, &all_batches)
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;

        if combined.num_rows() == 0 {
            self.delta.clear();
            return Ok(false);
        }

        // Sort by KEY ASC then criteria, so the best row per KEY group comes
        // first.
        let mut sort_columns = Vec::new();
        for &ki in &self.key_column_indices {
            sort_columns.push(arrow::compute::SortColumn {
                values: Arc::clone(combined.column(ki)),
                options: Some(arrow::compute::SortOptions {
                    descending: false,
                    nulls_first: false,
                }),
            });
        }
        for criterion in sort_criteria {
            sort_columns.push(arrow::compute::SortColumn {
                values: Arc::clone(combined.column(criterion.col_index)),
                options: Some(arrow::compute::SortOptions {
                    descending: !criterion.ascending,
                    nulls_first: criterion.nulls_first,
                }),
            });
        }

        let sorted_indices =
            arrow::compute::lexsort_to_indices(&sort_columns, None).map_err(arrow_err)?;
        let sorted_columns: Vec<_> = combined
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col.as_ref(), &sorted_indices, None))
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
        let sorted = RecordBatch::try_new(Arc::clone(&self.schema), sorted_columns)
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;

        // Dedup: keep first (best) row per KEY group.
        // Rows are key-sorted, so a group boundary is simply a key change
        // relative to the previous row.
        let mut keep_indices: Vec<u32> = Vec::new();
        let mut prev_key: Option<Vec<ScalarKey>> = None;
        for row_idx in 0..sorted.num_rows() {
            let key = extract_scalar_key(&sorted, &self.key_column_indices, row_idx);
            let is_new_group = match &prev_key {
                None => true,
                Some(prev) => *prev != key,
            };
            if is_new_group {
                keep_indices.push(row_idx as u32);
                prev_key = Some(key);
            }
        }

        let keep_array = arrow_array::UInt32Array::from(keep_indices);
        let output_columns: Vec<_> = sorted
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col.as_ref(), &keep_array, None))
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
        let pruned = RecordBatch::try_new(Arc::clone(&self.schema), output_columns)
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;

        // Detect whether the best-per-KEY set actually changed.
        // Comparison is on (KEY → criteria values) maps, so dominated rows
        // replacing other dominated rows do not count as progress.
        let new_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> = {
            let mut map = HashMap::new();
            for row_idx in 0..pruned.num_rows() {
                let key = extract_scalar_key(&pruned, &self.key_column_indices, row_idx);
                let criteria: Vec<ScalarKey> = sort_criteria
                    .iter()
                    .flat_map(|c| extract_scalar_key(&pruned, &[c.col_index], row_idx))
                    .collect();
                map.insert(key, criteria);
            }
            map
        };
        let changed = old_best != new_best;

        tracing::debug!(
            rule = %self.rule_name,
            old_keys = old_best.len(),
            new_keys = new_best.len(),
            changed = changed,
            "BEST BY merge"
        );

        // Replace facts with the pruned set.
        self.facts_bytes = batch_byte_size(&pruned);
        self.facts = vec![pruned];
        if changed {
            // Delta is conceptually the new/improved facts, but since we
            // replaced the entire set, just mark delta non-empty.
            self.delta = self.facts.clone();
        } else {
            self.delta.clear();
        }

        // Rebuild row dedup from pruned facts for consistency.
        self.row_dedup = RowDedupState::try_new(&self.schema);
        if let Some(ref mut rd) = self.row_dedup {
            rd.ingest_existing(&self.facts, &self.schema);
        }

        Ok(changed)
    }

    /// Build a map from KEY column values to sort criteria values.
    ///
    /// Assumes `self.facts` holds at most one row per KEY (guaranteed after
    /// BEST BY pruning); later rows with the same KEY would overwrite earlier
    /// ones.
    fn build_key_criteria_map(
        &self,
        sort_criteria: &[SortCriterion],
    ) -> HashMap<Vec<ScalarKey>, Vec<ScalarKey>> {
        let mut map = HashMap::new();
        for batch in &self.facts {
            for row_idx in 0..batch.num_rows() {
                let key = extract_scalar_key(batch, &self.key_column_indices, row_idx);
                let criteria: Vec<ScalarKey> = sort_criteria
                    .iter()
                    .flat_map(|c| extract_scalar_key(batch, &[c.col_index], row_idx))
                    .collect();
                map.insert(key, criteria);
            }
        }
        map
    }
}
800
801/// Estimate byte size of a RecordBatch.
802fn batch_byte_size(batch: &RecordBatch) -> usize {
803    batch
804        .columns()
805        .iter()
806        .map(|col| col.get_buffer_memory_size())
807        .sum()
808}
809
810// ---------------------------------------------------------------------------
811// Float rounding for stable dedup
812// ---------------------------------------------------------------------------
813
814/// Round all Float64 columns to 12 decimal places for stable dedup.
815fn round_float_columns(batches: &[RecordBatch]) -> Vec<RecordBatch> {
816    batches
817        .iter()
818        .map(|batch| {
819            let schema = batch.schema();
820            let has_float = schema
821                .fields()
822                .iter()
823                .any(|f| *f.data_type() == arrow_schema::DataType::Float64);
824            if !has_float {
825                return batch.clone();
826            }
827
828            let columns: Vec<arrow_array::ArrayRef> = batch
829                .columns()
830                .iter()
831                .enumerate()
832                .map(|(i, col)| {
833                    if *schema.field(i).data_type() == arrow_schema::DataType::Float64 {
834                        let arr = col
835                            .as_any()
836                            .downcast_ref::<arrow_array::Float64Array>()
837                            .unwrap();
838                        let rounded: arrow_array::Float64Array = arr
839                            .iter()
840                            .map(|v| v.map(|f| (f * 1e12).round() / 1e12))
841                            .collect();
842                        Arc::new(rounded) as arrow_array::ArrayRef
843                    } else {
844                        Arc::clone(col)
845                    }
846                })
847                .collect();
848
849            RecordBatch::try_new(schema, columns).unwrap_or_else(|_| batch.clone())
850        })
851        .collect()
852}
853
854// ---------------------------------------------------------------------------
855// LeftAntiJoin delta deduplication
856// ---------------------------------------------------------------------------
857
/// Row threshold above which the vectorized Arrow LeftAntiJoin dedup path is used.
///
/// Below this threshold the persistent `RowDedupState` HashSet is O(M) and
/// avoids rebuilding the existing-row set; above it DataFusion's vectorized
/// HashJoinExec is more cache-efficient.
///
/// NOTE(review): 300 looks empirically chosen — confirm against benchmarks
/// before changing.
const DEDUP_ANTI_JOIN_THRESHOLD: usize = 300;
864
865/// Deduplicate `candidates` against `existing` using DataFusion's HashJoinExec.
866///
867/// Returns rows in `candidates` that do not appear in `existing` (LeftAnti semantics).
868/// `null_equals_null = true` so NULLs are treated as equal for dedup purposes.
869async fn arrow_left_anti_dedup(
870    candidates: Vec<RecordBatch>,
871    existing: &[RecordBatch],
872    schema: &SchemaRef,
873    task_ctx: &Arc<TaskContext>,
874) -> DFResult<Vec<RecordBatch>> {
875    if existing.is_empty() || existing.iter().all(|b| b.num_rows() == 0) {
876        return Ok(candidates);
877    }
878
879    let left: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(candidates, Arc::clone(schema)));
880    let right: Arc<dyn ExecutionPlan> =
881        Arc::new(InMemoryExec::new(existing.to_vec(), Arc::clone(schema)));
882
883    let on: Vec<(
884        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
885        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
886    )> = schema
887        .fields()
888        .iter()
889        .enumerate()
890        .map(|(i, field)| {
891            let l: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
892                datafusion::physical_plan::expressions::Column::new(field.name(), i),
893            );
894            let r: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
895                datafusion::physical_plan::expressions::Column::new(field.name(), i),
896            );
897            (l, r)
898        })
899        .collect();
900
901    if on.is_empty() {
902        return Ok(vec![]);
903    }
904
905    let join = HashJoinExec::try_new(
906        left,
907        right,
908        on,
909        None,
910        &JoinType::LeftAnti,
911        None,
912        PartitionMode::CollectLeft,
913        datafusion::common::NullEquality::NullEqualsNull,
914    )?;
915
916    let join_arc: Arc<dyn ExecutionPlan> = Arc::new(join);
917    collect_all_partitions(&join_arc, task_ctx.clone()).await
918}
919
920// ---------------------------------------------------------------------------
921// Plan types for fixpoint rules
922// ---------------------------------------------------------------------------
923
/// IS-ref binding: a reference from a clause body to a derived relation.
///
/// Resolved at evaluation time by looking up `derived_scan_index` in the
/// `DerivedScanRegistry`.
#[derive(Debug, Clone)]
pub struct IsRefBinding {
    /// Index into the DerivedScanRegistry.
    pub derived_scan_index: usize,
    /// Name of the rule being referenced.
    pub rule_name: String,
    /// Whether this is a self-reference (rule references itself).
    pub is_self_ref: bool,
    /// Whether this is a negated reference (NOT IS).
    pub negated: bool,
    /// For negated IS-refs: `(left_body_col, right_derived_col)` pairs for anti-join filtering.
    ///
    /// `left_body_col` is the VID column in the clause body (e.g., `"n._vid"`);
    /// `right_derived_col` is the corresponding KEY column in the negated rule's facts (e.g., `"n"`).
    /// Empty for non-negated IS-refs.
    pub anti_join_cols: Vec<(String, String)>,
    /// Whether the target rule has a PROB column.
    pub target_has_prob: bool,
    /// Name of the PROB column in the target rule, if any.
    pub target_prob_col: Option<String>,
    /// `(body_col, derived_col)` pairs for provenance tracking.
    ///
    /// Used by shared-proof detection to find which source facts a derived row
    /// consumed. Populated for all IS-refs (not just negated ones).
    pub provenance_join_cols: Vec<(String, String)>,
}
951
/// A single clause (body) within a fixpoint rule.
///
/// Each clause is evaluated independently per iteration; the candidate rows
/// of all clauses in a rule are unioned before being merged into the rule's
/// fact set.
#[derive(Debug)]
pub struct FixpointClausePlan {
    /// The logical plan for the clause body.
    pub body_logical: LogicalPlan,
    /// IS-ref bindings used by this clause.
    pub is_ref_bindings: Vec<IsRefBinding>,
    /// Priority value for this clause (if PRIORITY semantics apply).
    pub priority: Option<i64>,
    /// ALONG binding variable names propagated from the planner.
    pub along_bindings: Vec<String>,
}
964
/// Physical plan for a single rule in a fixpoint stratum.
///
/// Carries both the clause bodies evaluated each iteration and the flags /
/// bindings driving post-fixpoint processing (FOLD, BEST BY, PRIORITY).
#[derive(Debug)]
pub struct FixpointRulePlan {
    /// Rule name.
    pub name: String,
    /// Clause bodies (each evaluates to candidate rows).
    pub clauses: Vec<FixpointClausePlan>,
    /// Output schema for this rule's derived relation.
    pub yield_schema: SchemaRef,
    /// Indices of KEY columns within yield_schema.
    pub key_column_indices: Vec<usize>,
    /// Priority value (if PRIORITY semantics apply).
    pub priority: Option<i64>,
    /// Whether this rule has FOLD semantics.
    pub has_fold: bool,
    /// FOLD bindings for post-fixpoint aggregation.
    pub fold_bindings: Vec<FoldBinding>,
    /// Whether this rule has BEST BY semantics.
    pub has_best_by: bool,
    /// BEST BY sort criteria for post-fixpoint selection.
    pub best_by_criteria: Vec<SortCriterion>,
    /// Whether this rule has PRIORITY semantics.
    pub has_priority: bool,
    /// Whether BEST BY should apply a deterministic secondary sort for
    /// tie-breaking. When false, tied rows are selected non-deterministically
    /// (faster but not repeatable across runs).
    pub deterministic: bool,
    /// Name of the PROB column in this rule's yield schema, if any.
    pub prob_column_name: Option<String>,
}
995
996// ---------------------------------------------------------------------------
997// run_fixpoint_loop — the core semi-naive iteration algorithm
998// ---------------------------------------------------------------------------
999
/// Run the semi-naive fixpoint iteration loop.
///
/// Evaluates all rules in a stratum repeatedly, feeding deltas back through
/// derived scan handles until convergence or limits are reached.
///
/// High-level phases:
/// 1. Build one `FixpointState` per rule (with monotonic FOLD state when the
///    rule has FOLD bindings).
/// 2. Iterate up to `max_iterations`: for each rule, refresh the derived scan
///    handles, evaluate every clause body, apply negated IS-ref semantics
///    (probabilistic complement or anti-join), and merge the union of clause
///    candidates into the rule's fact set (BEST BY rules use a specialized
///    merge). Provenance is recorded per delta when a tracker is supplied.
/// 3. After convergence, run per-rule post-processing: shared-lineage
///    detection, optional exact WMC, then the FOLD/BEST BY/PRIORITY chain.
///
/// Returns `DataFusionError::Execution` wrapping
/// `LocyRuntimeError::NonConvergence` if the timeout elapses mid-loop or all
/// `max_iterations` are exhausted without reaching a fixpoint.
#[expect(clippy::too_many_arguments, reason = "Fixpoint loop needs all context")]
async fn run_fixpoint_loop(
    rules: Vec<FixpointRulePlan>,
    max_iterations: usize,
    timeout: Duration,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    params: HashMap<String, Value>,
    registry: Arc<DerivedScanRegistry>,
    output_schema: SchemaRef,
    max_derived_bytes: usize,
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    top_k_proofs: usize,
) -> DFResult<Vec<RecordBatch>> {
    let start = Instant::now();
    let task_ctx = session_ctx.read().task_ctx();

    // Initialize per-rule state. Rules with FOLD bindings get a monotonic
    // aggregation state so running aggregates can be maintained incrementally.
    let mut states: Vec<FixpointState> = rules
        .iter()
        .map(|rule| {
            let monotonic_agg = if !rule.fold_bindings.is_empty() {
                let bindings: Vec<MonotonicFoldBinding> = rule
                    .fold_bindings
                    .iter()
                    .map(|fb| MonotonicFoldBinding {
                        fold_name: fb.output_name.clone(),
                        kind: fb.kind.clone(),
                        input_col_index: fb.input_col_index,
                    })
                    .collect();
                Some(MonotonicAggState::new(bindings))
            } else {
                None
            };
            FixpointState::new(
                rule.name.clone(),
                Arc::clone(&rule.yield_schema),
                rule.key_column_indices.clone(),
                max_derived_bytes,
                monotonic_agg,
                strict_probability_domain,
            )
        })
        .collect();

    // Main iteration loop
    let mut converged = false;
    // Number of iterations actually executed (reported via `iteration_counts`).
    let mut total_iters = 0usize;
    for iteration in 0..max_iterations {
        total_iters = iteration + 1;
        tracing::debug!("fixpoint iteration {}", iteration);
        let mut any_changed = false;

        for rule_idx in 0..rules.len() {
            let rule = &rules[rule_idx];

            // Update derived scan handles for this rule's clauses so IS-refs
            // observe the latest facts/deltas of the other rules.
            update_derived_scan_handles(&registry, &states, rule_idx, &rules);

            // Evaluate clause bodies, tracking per-clause candidates for provenance.
            let mut all_candidates = Vec::new();
            let mut clause_candidates: Vec<Vec<RecordBatch>> = Vec::new();
            for clause in &rule.clauses {
                let mut batches = execute_subplan(
                    &clause.body_logical,
                    &params,
                    &HashMap::new(),
                    &graph_ctx,
                    &session_ctx,
                    &storage,
                    &schema_info,
                )
                .await?;
                // Apply negated IS-ref semantics: probabilistic complement or anti-join.
                for binding in &clause.is_ref_bindings {
                    if binding.negated
                        && !binding.anti_join_cols.is_empty()
                        && let Some(entry) = registry.get(binding.derived_scan_index)
                    {
                        let neg_facts = entry.data.read().clone();
                        if !neg_facts.is_empty() {
                            if binding.target_has_prob && rule.prob_column_name.is_some() {
                                // Probabilistic complement: add 1-p column instead of filtering.
                                let complement_col =
                                    format!("__prob_complement_{}", binding.rule_name);
                                if let Some(prob_col) = &binding.target_prob_col {
                                    batches = apply_prob_complement_composite(
                                        batches,
                                        &neg_facts,
                                        &binding.anti_join_cols,
                                        prob_col,
                                        &complement_col,
                                    )?;
                                } else {
                                    // target_has_prob but no prob_col: fall back to anti-join.
                                    batches = apply_anti_join_composite(
                                        batches,
                                        &neg_facts,
                                        &binding.anti_join_cols,
                                    )?;
                                }
                            } else {
                                // Boolean exclusion: anti-join (existing behavior)
                                batches = apply_anti_join_composite(
                                    batches,
                                    &neg_facts,
                                    &binding.anti_join_cols,
                                )?;
                            }
                        }
                    }
                }
                // Multiply complement columns into the PROB column (if any) and clean up.
                // Complement columns are identified by the "__prob_complement_" prefix
                // added above.
                let complement_cols: Vec<String> = if !batches.is_empty() {
                    batches[0]
                        .schema()
                        .fields()
                        .iter()
                        .filter(|f| f.name().starts_with("__prob_complement_"))
                        .map(|f| f.name().clone())
                        .collect()
                } else {
                    vec![]
                };
                if !complement_cols.is_empty() {
                    batches = multiply_prob_factors(
                        batches,
                        rule.prob_column_name.as_deref(),
                        &complement_cols,
                    )?;
                }

                clause_candidates.push(batches.clone());
                all_candidates.extend(batches);
            }

            // Merge candidates into facts.
            // For BEST BY rules, use a specialized merge that keeps only the
            // best row per KEY group, enabling convergence on cyclic graphs.
            let changed = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
                states[rule_idx].merge_best_by(all_candidates, &rule.best_by_criteria)?
            } else {
                states[rule_idx]
                    .merge_delta(all_candidates, Some(Arc::clone(&task_ctx)))
                    .await?
            };
            if changed {
                any_changed = true;
                // Record provenance for newly derived facts when tracker is present.
                if let Some(ref tracker) = derivation_tracker {
                    record_provenance(
                        tracker,
                        rule,
                        &states[rule_idx],
                        &clause_candidates,
                        iteration,
                        &registry,
                        top_k_proofs,
                    );
                }
            }
        }

        // Check convergence: no rule produced new facts this iteration and
        // every per-rule state reports itself converged.
        if !any_changed && states.iter().all(|s| s.is_converged()) {
            tracing::debug!("fixpoint converged after {} iterations", iteration + 1);
            converged = true;
            break;
        }

        // Check timeout (between iterations, not within one).
        if start.elapsed() > timeout {
            return Err(datafusion::error::DataFusionError::Execution(
                LocyRuntimeError::NonConvergence {
                    iterations: iteration + 1,
                }
                .to_string(),
            ));
        }
    }

    // Write per-rule iteration counts to the shared slot.
    // A poisoned lock just means iteration counts are not reported.
    if let Ok(mut counts) = iteration_counts.write() {
        for rule in &rules {
            counts.insert(rule.name.clone(), total_iters);
        }
    }

    // If we exhausted all iterations without converging, return a non-convergence error.
    if !converged {
        return Err(datafusion::error::DataFusionError::Execution(
            LocyRuntimeError::NonConvergence {
                iterations: max_iterations,
            }
            .to_string(),
        ));
    }

    // Post-fixpoint processing per rule and collect output
    let task_ctx = session_ctx.read().task_ctx();
    let mut all_output = Vec::new();

    for (rule_idx, state) in states.into_iter().enumerate() {
        let rule = &rules[rule_idx];
        let mut facts = state.into_facts();
        if facts.is_empty() {
            continue;
        }

        // Detect shared proofs before FOLD collapses groups.
        let shared_info = if let Some(ref tracker) = derivation_tracker {
            detect_shared_lineage(rule, &facts, tracker, &warnings_slot)
        } else {
            None
        };

        // Apply BDD for shared groups if exact_probability is enabled.
        if exact_probability
            && let Some(ref info) = shared_info
            && let Some(ref tracker) = derivation_tracker
        {
            facts = apply_exact_wmc(
                facts,
                rule,
                info,
                tracker,
                max_bdd_variables,
                &warnings_slot,
                &approximate_slot,
            )?;
        }

        let processed = apply_post_fixpoint_chain(
            facts,
            rule,
            &task_ctx,
            strict_probability_domain,
            probability_epsilon,
        )
        .await?;
        all_output.extend(processed);
    }

    // If no output, return empty batch with output schema
    if all_output.is_empty() {
        all_output.push(RecordBatch::new_empty(output_schema));
    }

    Ok(all_output)
}
1264
1265// ---------------------------------------------------------------------------
1266// Provenance recording helpers
1267// ---------------------------------------------------------------------------
1268
1269/// Record provenance for all newly derived facts (rows in the current delta).
1270///
1271/// Called after `merge_delta` returns `true`. Attributes each new fact to the
1272/// clause most likely to have produced it, using first-derivation-wins semantics.
1273fn record_provenance(
1274    tracker: &Arc<ProvenanceStore>,
1275    rule: &FixpointRulePlan,
1276    state: &FixpointState,
1277    clause_candidates: &[Vec<RecordBatch>],
1278    iteration: usize,
1279    registry: &Arc<DerivedScanRegistry>,
1280    top_k_proofs: usize,
1281) {
1282    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
1283
1284    // Pre-compute base fact probabilities for top-k mode.
1285    let base_probs = if top_k_proofs > 0 {
1286        tracker.base_fact_probs()
1287    } else {
1288        HashMap::new()
1289    };
1290
1291    for delta_batch in state.all_delta() {
1292        for row_idx in 0..delta_batch.num_rows() {
1293            let row_hash = format!(
1294                "{:?}",
1295                extract_scalar_key(delta_batch, &all_indices, row_idx)
1296            )
1297            .into_bytes();
1298            let fact_row = batch_row_to_value_map(delta_batch, row_idx);
1299            let clause_index =
1300                find_clause_for_row(delta_batch, row_idx, &all_indices, clause_candidates);
1301
1302            let support = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);
1303
1304            let proof_probability = if top_k_proofs > 0 {
1305                compute_proof_probability(&support, &base_probs)
1306            } else {
1307                None
1308            };
1309
1310            let entry = ProvenanceAnnotation {
1311                rule_name: rule.name.clone(),
1312                clause_index,
1313                support,
1314                along_values: {
1315                    let along_names: Vec<String> = rule
1316                        .clauses
1317                        .get(clause_index)
1318                        .map(|c| c.along_bindings.clone())
1319                        .unwrap_or_default();
1320                    along_names
1321                        .iter()
1322                        .filter_map(|name| fact_row.get(name).map(|v| (name.clone(), v.clone())))
1323                        .collect()
1324                },
1325                iteration,
1326                fact_row,
1327                proof_probability,
1328            };
1329            if top_k_proofs > 0 {
1330                tracker.record_top_k(row_hash, entry, top_k_proofs);
1331            } else {
1332                tracker.record(row_hash, entry);
1333            }
1334        }
1335    }
1336}
1337
1338/// Collect IS-ref input facts for a derived row using provenance join columns.
1339///
1340/// For each non-negated IS-ref binding in the clause, extracts body-side key
1341/// values from the delta row and finds matching source rows in the registry.
1342/// Returns a `ProofTerm` for each match (with the source fact hash).
1343fn collect_is_ref_inputs(
1344    rule: &FixpointRulePlan,
1345    clause_index: usize,
1346    delta_batch: &RecordBatch,
1347    row_idx: usize,
1348    registry: &Arc<DerivedScanRegistry>,
1349) -> Vec<ProofTerm> {
1350    let clause = match rule.clauses.get(clause_index) {
1351        Some(c) => c,
1352        None => return vec![],
1353    };
1354
1355    let mut inputs = Vec::new();
1356    let delta_schema = delta_batch.schema();
1357
1358    for binding in &clause.is_ref_bindings {
1359        if binding.negated {
1360            continue;
1361        }
1362        if binding.provenance_join_cols.is_empty() {
1363            continue;
1364        }
1365
1366        // Extract body-side values from the delta row for each provenance join col.
1367        let body_values: Vec<(String, ScalarKey)> = binding
1368            .provenance_join_cols
1369            .iter()
1370            .filter_map(|(body_col, _derived_col)| {
1371                let col_idx = delta_schema
1372                    .fields()
1373                    .iter()
1374                    .position(|f| f.name() == body_col)?;
1375                let key = extract_scalar_key(delta_batch, &[col_idx], row_idx);
1376                Some((body_col.clone(), key.into_iter().next()?))
1377            })
1378            .collect();
1379
1380        if body_values.len() != binding.provenance_join_cols.len() {
1381            continue;
1382        }
1383
1384        // Read current data from the registry entry for this IS-ref's rule.
1385        let entry = match registry.get(binding.derived_scan_index) {
1386            Some(e) => e,
1387            None => continue,
1388        };
1389        let source_batches = entry.data.read();
1390        let source_schema = &entry.schema;
1391
1392        // Find matching source rows and hash them.
1393        for src_batch in source_batches.iter() {
1394            let all_src_indices: Vec<usize> = (0..src_batch.num_columns()).collect();
1395            for src_row in 0..src_batch.num_rows() {
1396                let matches = binding.provenance_join_cols.iter().enumerate().all(
1397                    |(i, (_body_col, derived_col))| {
1398                        let src_col_idx = source_schema
1399                            .fields()
1400                            .iter()
1401                            .position(|f| f.name() == derived_col);
1402                        match src_col_idx {
1403                            Some(idx) => {
1404                                let src_key = extract_scalar_key(src_batch, &[idx], src_row);
1405                                src_key.first() == Some(&body_values[i].1)
1406                            }
1407                            None => false,
1408                        }
1409                    },
1410                );
1411                if matches {
1412                    let fact_hash = format!(
1413                        "{:?}",
1414                        extract_scalar_key(src_batch, &all_src_indices, src_row)
1415                    )
1416                    .into_bytes();
1417                    inputs.push(ProofTerm {
1418                        source_rule: binding.rule_name.clone(),
1419                        base_fact_id: fact_hash,
1420                    });
1421                }
1422            }
1423        }
1424    }
1425
1426    inputs
1427}
1428
1429// ---------------------------------------------------------------------------
1430// Shared-lineage detection
1431// ---------------------------------------------------------------------------
1432
// Detection strategy (implemented by `detect_shared_lineage` below; kept as a
// plain comment so it is not rustdoc-attached to `SharedGroupRow`):
//
// Detect KEY groups in a rule's pre-fold facts where recursive derivation
// may violate the independence assumption of MNOR/MPROD.
//
// Uses a two-tier strategy:
// 1. **Precise**: If the `ProvenanceStore` has populated `support` for facts
//    in the group, we recursively compute lineage (Cui & Widom 2000) and
//    check for pairwise overlap. A shared base fact proves a dependency.
// 2. **Structural fallback**: When lineage tracking is unavailable (e.g., the
//    IS-ref subject variables were projected away), we check whether any fact
//    in a multi-row group was derived by a clause that has IS-ref bindings.
//    Recursive derivation through shared relations is a strong signal that
//    proof paths may share intermediate nodes.

/// Per-row data collected during shared-lineage detection.
#[expect(
    dead_code,
    reason = "Fields accessed via SharedLineageInfo in detect_shared_lineage"
)]
pub(crate) struct SharedGroupRow {
    // Byte key identifying the fact row across all columns (see `fact_hash_key`).
    pub fact_hash: Vec<u8>,
    // Base facts reachable from this row's derivation (its lineage set).
    pub lineage: HashSet<Vec<u8>>,
}
1455
/// Information about groups with shared proofs, returned by `detect_shared_lineage`.
///
/// Consumed by `apply_exact_wmc` when exact probability computation is enabled.
pub(crate) struct SharedLineageInfo {
    /// KEY group → rows with their base fact sets.
    pub shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>>,
}
1461
1462/// Build a byte key that uniquely identifies a row across all columns.
1463fn fact_hash_key(batch: &RecordBatch, all_indices: &[usize], row_idx: usize) -> Vec<u8> {
1464    format!("{:?}", extract_scalar_key(batch, all_indices, row_idx)).into_bytes()
1465}
1466
1467/// Emits at most one `SharedProbabilisticDependency` warning per rule.
/// Returns `Some(SharedLineageInfo)` if any group has shared proofs.
///
/// Detection is two-tiered: when the tracker recorded IS-ref inputs, base-fact
/// lineages are compared pairwise for overlap (precise); otherwise a
/// structural fallback flags groups whose facts were derived by a clause with
/// positive IS-ref bindings (i.e. a recursive derivation).
fn detect_shared_lineage(
    rule: &FixpointRulePlan,
    pre_fold_facts: &[RecordBatch],
    tracker: &Arc<ProvenanceStore>,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
) -> Option<SharedLineageInfo> {
    use crate::query::df_graph::locy_fold::FoldAggKind;
    use uni_locy::{RuntimeWarning, RuntimeWarningCode};

    // Only check rules with MNOR/MPROD fold bindings.
    let has_prob_fold = rule
        .fold_bindings
        .iter()
        .any(|fb| matches!(fb.kind, FoldAggKind::Nor | FoldAggKind::Prod));
    if !has_prob_fold {
        return None;
    }

    // Group facts by KEY columns.
    let key_indices = &rule.key_column_indices;
    // The full-row scalar key doubles as the fact's identity hash.
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    // KEY → identity hashes of every pre-fold fact in that group.
    let mut groups: HashMap<Vec<ScalarKey>, Vec<Vec<u8>>> = HashMap::new();
    for batch in pre_fold_facts {
        for row_idx in 0..batch.num_rows() {
            let key = extract_scalar_key(batch, key_indices, row_idx);
            let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
            groups.entry(key).or_default().push(fact_hash);
        }
    }

    let mut shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>> = HashMap::new();
    let mut any_shared = false;

    // Check each group with ≥2 rows.
    for (key, fact_hashes) in &groups {
        if fact_hashes.len() < 2 {
            continue;
        }

        // Tier 1: precise base-fact overlap detection via tracker inputs.
        // per_row_bases[i] = base-fact lineage of the i-th fact in the group;
        // a fact with no tracked inputs is its own single base (see
        // `compute_lineage`).
        let mut has_inputs = false;
        let mut per_row_bases: Vec<HashSet<Vec<u8>>> = Vec::new();
        for fh in fact_hashes {
            let bases = compute_lineage(fh, tracker, &mut HashSet::new());
            if let Some(entry) = tracker.lookup(fh)
                && !entry.support.is_empty()
            {
                has_inputs = true;
            }
            per_row_bases.push(bases);
        }

        let shared_found = if has_inputs {
            // At least some facts have tracked inputs — do precise comparison.
            // Any overlapping pair of lineages means shared proofs.
            let mut found = false;
            'outer: for i in 0..per_row_bases.len() {
                for j in (i + 1)..per_row_bases.len() {
                    if !per_row_bases[i].is_disjoint(&per_row_bases[j]) {
                        found = true;
                        break 'outer;
                    }
                }
            }
            found
        } else {
            // Tier 2: structural fallback — check if any fact in the group was
            // derived by a clause with IS-ref bindings (recursive derivation).
            fact_hashes.iter().any(|fh| {
                tracker.lookup(fh).is_some_and(|entry| {
                    rule.clauses
                        .get(entry.clause_index)
                        .is_some_and(|clause| clause.is_ref_bindings.iter().any(|b| !b.negated))
                })
            })
        };

        if shared_found {
            any_shared = true;
            // Collect the group rows with their base facts for BDD use.
            let rows: Vec<SharedGroupRow> = fact_hashes
                .iter()
                .zip(per_row_bases.into_iter())
                .map(|(fh, bases)| SharedGroupRow {
                    fact_hash: fh.clone(),
                    lineage: bases,
                })
                .collect();
            shared_groups.insert(key.clone(), rows);
        }
    }

    // Phase 5: Cross-group correlation warning.
    // Check if any IS-ref input fact appears in multiple KEY groups.
    // This is independent of within-group sharing: even rules whose KEY groups
    // each have only one post-fold row can exhibit cross-group correlation when
    // different groups consume the same IS-ref base fact.
    {
        // input base-fact id → set of KEY groups consuming it.
        let mut input_to_groups: HashMap<Vec<u8>, HashSet<Vec<ScalarKey>>> = HashMap::new();
        for (key, fact_hashes) in &groups {
            for fh in fact_hashes {
                if let Some(entry) = tracker.lookup(fh) {
                    for input in &entry.support {
                        input_to_groups
                            .entry(input.base_fact_id.clone())
                            .or_default()
                            .insert(key.clone());
                    }
                }
            }
        }
        let has_cross_group = input_to_groups.values().any(|g| g.len() > 1);
        // Warn at most once per rule for this warning code.
        if has_cross_group && let Ok(mut warnings) = warnings_slot.write() {
            let already_warned = warnings.iter().any(|w| {
                w.code == RuntimeWarningCode::CrossGroupCorrelationNotExact
                    && w.rule_name == rule.name
            });
            if !already_warned {
                warnings.push(RuntimeWarning {
                    code: RuntimeWarningCode::CrossGroupCorrelationNotExact,
                    message: format!(
                        "Rule '{}': IS-ref base facts are shared across different KEY \
                         groups. BDD corrects per-group probabilities but cannot account \
                         for cross-group correlations.",
                        rule.name
                    ),
                    rule_name: rule.name.clone(),
                    variable_count: None,
                    key_group: None,
                });
            }
        }
    }

    if any_shared {
        // Warn once per rule that independence-mode results may overestimate.
        if let Ok(mut warnings) = warnings_slot.write() {
            let already_warned = warnings.iter().any(|w| {
                w.code == RuntimeWarningCode::SharedProbabilisticDependency
                    && w.rule_name == rule.name
            });
            if !already_warned {
                warnings.push(RuntimeWarning {
                    code: RuntimeWarningCode::SharedProbabilisticDependency,
                    message: format!(
                        "Rule '{}' aggregates with MNOR/MPROD but some proof paths \
                         share intermediate facts, violating the independence assumption. \
                         Results may overestimate probability.",
                        rule.name
                    ),
                    rule_name: rule.name.clone(),
                    variable_count: None,
                    key_group: None,
                });
            }
        }
        Some(SharedLineageInfo { shared_groups })
    } else {
        None
    }
}
1629
/// Record provenance and detect shared proofs for non-recursive strata.
///
/// Non-recursive rules are evaluated in a single pass (no fixpoint loop), so
/// the regular `record_provenance` + `detect_shared_lineage` path is never hit.
/// This function bridges that gap by recording a `ProvenanceAnnotation` for every
/// fact produced by each clause and then running the same two-tier detection
/// logic used by the recursive path.
pub(crate) fn record_and_detect_lineage_nonrecursive(
    rule: &FixpointRulePlan,
    tagged_clause_facts: &[(usize, Vec<RecordBatch>)],
    tracker: &Arc<ProvenanceStore>,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
    registry: &Arc<DerivedScanRegistry>,
    top_k_proofs: usize,
) -> Option<SharedLineageInfo> {
    // The full-row scalar key doubles as the fact's identity hash.
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    // Pre-compute base fact probabilities for top-k mode.
    let base_probs = if top_k_proofs > 0 {
        tracker.base_fact_probs()
    } else {
        HashMap::new()
    };

    // Record provenance for each clause's facts.
    for (clause_index, batches) in tagged_clause_facts {
        for batch in batches {
            for row_idx in 0..batch.num_rows() {
                let row_hash = fact_hash_key(batch, &all_indices, row_idx);
                let fact_row = batch_row_to_value_map(batch, row_idx);

                // IS-ref inputs consumed by this row — its direct support.
                let support = collect_is_ref_inputs(rule, *clause_index, batch, row_idx, registry);

                // Proof probability is only needed when ranking top-k proofs.
                let proof_probability = if top_k_proofs > 0 {
                    compute_proof_probability(&support, &base_probs)
                } else {
                    None
                };

                let entry = ProvenanceAnnotation {
                    rule_name: rule.name.clone(),
                    clause_index: *clause_index,
                    support,
                    // Keep only the values of columns bound by this clause's
                    // ALONG bindings.
                    along_values: {
                        let along_names: Vec<String> = rule
                            .clauses
                            .get(*clause_index)
                            .map(|c| c.along_bindings.clone())
                            .unwrap_or_default();
                        along_names
                            .iter()
                            .filter_map(|name| {
                                fact_row.get(name).map(|v| (name.clone(), v.clone()))
                            })
                            .collect()
                    },
                    // Single-pass evaluation: everything is "iteration 0".
                    iteration: 0,
                    fact_row,
                    proof_probability,
                };
                // Top-k mode keeps only the best proofs per fact.
                if top_k_proofs > 0 {
                    tracker.record_top_k(row_hash, entry, top_k_proofs);
                } else {
                    tracker.record(row_hash, entry);
                }
            }
        }
    }

    // Flatten all clause facts and run detection.
    let all_facts: Vec<RecordBatch> = tagged_clause_facts
        .iter()
        .flat_map(|(_, batches)| batches.iter().cloned())
        .collect();
    detect_shared_lineage(rule, &all_facts, tracker, warnings_slot)
}
1706
/// Apply exact weighted model counting (WMC) for shared-lineage groups.
///
/// Replaces multiple rows in groups with shared lineage with a single
/// representative row whose PROB column carries the BDD-computed exact
/// probability (Sang et al. 2005). For groups that exceed
/// `max_bdd_variables`, rows are left unchanged and a `BddLimitExceeded`
/// warning is emitted.
pub(crate) fn apply_exact_wmc(
    pre_fold_facts: Vec<RecordBatch>,
    rule: &FixpointRulePlan,
    shared_info: &SharedLineageInfo,
    tracker: &Arc<ProvenanceStore>,
    max_bdd_variables: usize,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
    approximate_slot: &Arc<StdRwLock<HashMap<String, Vec<String>>>>,
) -> DFResult<Vec<RecordBatch>> {
    use crate::query::df_graph::locy_bdd::{SemiringOp, weighted_model_count};
    use crate::query::df_graph::locy_fold::FoldAggKind;
    use uni_locy::{RuntimeWarning, RuntimeWarningCode};

    // Find the MNOR/MPROD fold binding to know which column to overwrite.
    let prob_fold = rule
        .fold_bindings
        .iter()
        .find(|fb| matches!(fb.kind, FoldAggKind::Nor | FoldAggKind::Prod));
    let prob_fold = match prob_fold {
        Some(f) => f,
        // No probabilistic fold — nothing to correct; pass rows through.
        None => return Ok(pre_fold_facts),
    };
    // MNOR aggregates disjunctively (noisy-OR), MPROD conjunctively.
    let semiring_op = if matches!(prob_fold.kind, FoldAggKind::Nor) {
        SemiringOp::Disjunction
    } else {
        SemiringOp::Conjunction
    };
    let prob_col_idx = prob_fold.input_col_index;
    let prob_col_name = rule.yield_schema.field(prob_col_idx).name().clone();

    let key_indices = &rule.key_column_indices;
    // The full-row scalar key doubles as the fact's identity hash.
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    // Build a set of shared group keys for quick lookup.
    let shared_keys: HashSet<Vec<ScalarKey>> = shared_info.shared_groups.keys().cloned().collect();

    // Phase 1: Collect all rows for each shared KEY group across all batches.
    // Store (batch_index, row_index) pairs for each group.
    struct GroupAccum {
        // One lineage set (base-fact hashes) per row in the group.
        base_facts: Vec<HashSet<Vec<u8>>>,
        // Base-fact hash → its PROB value; the WMC variable weights.
        base_probs: HashMap<Vec<u8>, f64>,
        /// First occurrence: (batch_index, row_index) — used as representative.
        representative: (usize, usize),
        // Every (batch_index, row_index) belonging to this group.
        row_locations: Vec<(usize, usize)>,
    }

    let mut group_accums: HashMap<Vec<ScalarKey>, GroupAccum> = HashMap::new();
    let mut non_shared_rows: Vec<(usize, usize)> = Vec::new(); // (batch_idx, row_idx)

    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
        for row_idx in 0..batch.num_rows() {
            let key = extract_scalar_key(batch, key_indices, row_idx);
            if shared_keys.contains(&key) {
                let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
                let bases = compute_lineage(&fact_hash, tracker, &mut HashSet::new());

                let accum = group_accums.entry(key).or_insert_with(|| GroupAccum {
                    base_facts: Vec::new(),
                    base_probs: HashMap::new(),
                    representative: (batch_idx, row_idx),
                    row_locations: Vec::new(),
                });

                // Look up probabilities for base facts.
                for bf in &bases {
                    if !accum.base_probs.contains_key(bf)
                        && let Some(entry) = tracker.lookup(bf)
                        && let Some(val) = entry.fact_row.get(&prob_col_name)
                        && let Some(p) = value_to_f64(val)
                    {
                        accum.base_probs.insert(bf.clone(), p);
                    }
                }

                accum.base_facts.push(bases);
                accum.row_locations.push((batch_idx, row_idx));
            } else {
                non_shared_rows.push((batch_idx, row_idx));
            }
        }
    }

    // Phase 2: Compute BDD for each shared group (across all batches).
    // Track which (batch_idx, row_idx) pairs to keep vs drop.
    let mut keep_rows: HashSet<(usize, usize)> = HashSet::new();
    // Map of (batch_idx, row_idx) → overridden PROB value (for BDD-succeeded groups).
    let mut overrides: HashMap<(usize, usize), f64> = HashMap::new();

    // All non-shared rows are kept.
    for &loc in &non_shared_rows {
        keep_rows.insert(loc);
    }

    for (key, accum) in &group_accums {
        let bdd_result = weighted_model_count(
            &accum.base_facts,
            &accum.base_probs,
            semiring_op,
            max_bdd_variables,
        );

        if bdd_result.approximated {
            // Emit BddLimitExceeded warning (one per key group).
            if let Ok(mut warnings) = warnings_slot.write() {
                let key_desc = format!("{key:?}");
                let already_warned = warnings.iter().any(|w| {
                    w.code == RuntimeWarningCode::BddLimitExceeded
                        && w.rule_name == rule.name
                        && w.key_group.as_deref() == Some(&key_desc)
                });
                if !already_warned {
                    warnings.push(RuntimeWarning {
                        code: RuntimeWarningCode::BddLimitExceeded,
                        message: format!(
                            "Rule '{}': BDD variable limit exceeded ({} > {}). \
                             Falling back to independence-mode result.",
                            rule.name, bdd_result.variable_count, max_bdd_variables
                        ),
                        rule_name: rule.name.clone(),
                        variable_count: Some(bdd_result.variable_count),
                        key_group: Some(key_desc),
                    });
                }
            }
            // Record this key group as approximate for downstream reporting.
            if let Ok(mut approx) = approximate_slot.write() {
                let key_desc = format!("{key:?}");
                approx.entry(rule.name.clone()).or_default().push(key_desc);
            }
            // Keep all rows unchanged.
            for &loc in &accum.row_locations {
                keep_rows.insert(loc);
            }
        } else {
            // BDD succeeded: keep one representative row with overridden PROB.
            keep_rows.insert(accum.representative);
            overrides.insert(accum.representative, bdd_result.probability);
        }
    }

    // Phase 3: Build output batches by filtering kept rows per batch.
    let mut result_batches = Vec::new();
    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
        let kept_indices: Vec<usize> = (0..batch.num_rows())
            .filter(|&row_idx| keep_rows.contains(&(batch_idx, row_idx)))
            .collect();

        if kept_indices.is_empty() {
            continue;
        }

        // `take` gathers the kept rows into fresh column arrays.
        let indices = arrow::array::UInt32Array::from(
            kept_indices.iter().map(|&i| i as u32).collect::<Vec<_>>(),
        );
        let mut columns: Vec<arrow::array::ArrayRef> = batch
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col, &indices, None))
            .collect::<Result<Vec<_>, _>>()
            .map_err(arrow_err)?;

        // Check if any kept rows have PROB overrides.
        // override_map is index-aligned with kept_indices (and thus with the
        // taken columns).
        let override_map: Vec<Option<f64>> = kept_indices
            .iter()
            .map(|&row_idx| overrides.get(&(batch_idx, row_idx)).copied())
            .collect();

        if override_map.iter().any(|o| o.is_some()) {
            // Rebuild the PROB column with overrides.
            // NOTE(review): if the PROB column is not Float64, non-overridden
            // rows fall back to 0.0 here — confirm the column is always
            // Float64 by construction upstream.
            let existing_prob = columns[prob_col_idx]
                .as_any()
                .downcast_ref::<arrow::array::Float64Array>();
            let new_values: Vec<f64> = override_map
                .iter()
                .enumerate()
                .map(|(i, ov)| match ov {
                    Some(p) => *p,
                    None => existing_prob.map(|arr| arr.value(i)).unwrap_or(0.0),
                })
                .collect();
            columns[prob_col_idx] = Arc::new(arrow::array::Float64Array::from(new_values));
        }

        let result_batch = RecordBatch::try_new(batch.schema(), columns).map_err(arrow_err)?;
        result_batches.push(result_batch);
    }

    Ok(result_batches)
}
1902
1903/// Extract an f64 from a `Value`, supporting Float and Int.
1904fn value_to_f64(val: &uni_common::Value) -> Option<f64> {
1905    match val {
1906        uni_common::Value::Float(f) => Some(*f),
1907        uni_common::Value::Int(i) => Some(*i as f64),
1908        _ => None,
1909    }
1910}
1911
1912/// Compute the lineage of a derived fact (Cui & Widom 2000).
1913///
1914/// Recursively traverses the provenance store to collect the set of base-level
1915/// fact hashes that contribute to this derivation. A base fact is one with no
1916/// IS-ref support (a graph-level fact). Intermediate facts are expanded
1917/// transitively through the store.
1918fn compute_lineage(
1919    fact_hash: &[u8],
1920    tracker: &Arc<ProvenanceStore>,
1921    visited: &mut HashSet<Vec<u8>>,
1922) -> HashSet<Vec<u8>> {
1923    if !visited.insert(fact_hash.to_vec()) {
1924        return HashSet::new(); // Cycle guard.
1925    }
1926
1927    match tracker.lookup(fact_hash) {
1928        Some(entry) if !entry.support.is_empty() => {
1929            let mut bases = HashSet::new();
1930            for input in &entry.support {
1931                let child_bases = compute_lineage(&input.base_fact_id, tracker, visited);
1932                bases.extend(child_bases);
1933            }
1934            bases
1935        }
1936        _ => {
1937            // Base fact (no tracker entry or no inputs).
1938            let mut set = HashSet::new();
1939            set.insert(fact_hash.to_vec());
1940            set
1941        }
1942    }
1943}
1944
1945/// Determine which clause produced a given row by checking each clause's candidates.
1946///
1947/// Returns the index of the first clause whose candidates contain a matching row.
1948/// Falls back to 0 if no match is found.
1949fn find_clause_for_row(
1950    delta_batch: &RecordBatch,
1951    row_idx: usize,
1952    all_indices: &[usize],
1953    clause_candidates: &[Vec<RecordBatch>],
1954) -> usize {
1955    let target_key = extract_scalar_key(delta_batch, all_indices, row_idx);
1956    for (clause_idx, batches) in clause_candidates.iter().enumerate() {
1957        for batch in batches {
1958            if batch.num_columns() != all_indices.len() {
1959                continue;
1960            }
1961            for r in 0..batch.num_rows() {
1962                if extract_scalar_key(batch, all_indices, r) == target_key {
1963                    return clause_idx;
1964                }
1965            }
1966        }
1967    }
1968    0
1969}
1970
1971/// Convert a single row from a `RecordBatch` at `row_idx` into a `HashMap<String, Value>`.
1972fn batch_row_to_value_map(
1973    batch: &RecordBatch,
1974    row_idx: usize,
1975) -> std::collections::HashMap<String, Value> {
1976    use uni_store::storage::arrow_convert::arrow_to_value;
1977
1978    let schema = batch.schema();
1979    schema
1980        .fields()
1981        .iter()
1982        .enumerate()
1983        .map(|(col_idx, field)| {
1984            let col = batch.column(col_idx);
1985            let val = arrow_to_value(col.as_ref(), row_idx, None);
1986            (field.name().clone(), val)
1987        })
1988        .collect()
1989}
1990
1991/// Filter `batches` to exclude rows where `left_col` VID appears in `neg_facts[right_col]`.
1992///
1993/// Implements anti-join semantics for negated IS-refs (`n IS NOT rule`): keeps only
1994/// rows whose subject VID is NOT present in the negated rule's fully-converged facts.
1995pub fn apply_anti_join(
1996    batches: Vec<RecordBatch>,
1997    neg_facts: &[RecordBatch],
1998    left_col: &str,
1999    right_col: &str,
2000) -> datafusion::error::Result<Vec<RecordBatch>> {
2001    use arrow::compute::filter_record_batch;
2002    use arrow_array::{Array as _, BooleanArray, UInt64Array};
2003
2004    // Collect right-side VIDs from the negated rule's derived facts.
2005    let mut banned: std::collections::HashSet<u64> = std::collections::HashSet::new();
2006    for batch in neg_facts {
2007        let Ok(idx) = batch.schema().index_of(right_col) else {
2008            continue;
2009        };
2010        let arr = batch.column(idx);
2011        let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2012            continue;
2013        };
2014        for i in 0..vids.len() {
2015            if !vids.is_null(i) {
2016                banned.insert(vids.value(i));
2017            }
2018        }
2019    }
2020
2021    if banned.is_empty() {
2022        return Ok(batches);
2023    }
2024
2025    // Filter body batches: keep rows where left_col NOT IN banned.
2026    let mut result = Vec::new();
2027    for batch in batches {
2028        let Ok(idx) = batch.schema().index_of(left_col) else {
2029            result.push(batch);
2030            continue;
2031        };
2032        let arr = batch.column(idx);
2033        let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2034            result.push(batch);
2035            continue;
2036        };
2037        let keep: Vec<bool> = (0..vids.len())
2038            .map(|i| vids.is_null(i) || !banned.contains(&vids.value(i)))
2039            .collect();
2040        let keep_arr = BooleanArray::from(keep);
2041        let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
2042        if filtered.num_rows() > 0 {
2043            result.push(filtered);
2044        }
2045    }
2046    Ok(result)
2047}
2048
2049/// Probabilistic complement for negated IS-refs targeting PROB rules.
2050///
2051/// Instead of filtering out matching VIDs (anti-join), this adds a complement column
2052/// `__prob_complement_{rule_name}` with value `1 - p` for each matching VID, and `1.0`
2053/// for VIDs not present in the negated rule's facts.
2054///
2055/// This implements the probabilistic complement semantics: `IS NOT risk` on a PROB rule
2056/// yields the probability that the entity is NOT risky.
2057pub fn apply_prob_complement(
2058    batches: Vec<RecordBatch>,
2059    neg_facts: &[RecordBatch],
2060    left_col: &str,
2061    right_col: &str,
2062    prob_col: &str,
2063    complement_col_name: &str,
2064) -> datafusion::error::Result<Vec<RecordBatch>> {
2065    use arrow_array::{Array as _, Float64Array, UInt64Array};
2066
2067    // Build VID → probability lookup from negative facts
2068    let mut prob_map: std::collections::HashMap<u64, f64> = std::collections::HashMap::new();
2069    for batch in neg_facts {
2070        let Ok(vid_idx) = batch.schema().index_of(right_col) else {
2071            continue;
2072        };
2073        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
2074            continue;
2075        };
2076        let Some(vids) = batch.column(vid_idx).as_any().downcast_ref::<UInt64Array>() else {
2077            continue;
2078        };
2079        let prob_arr = batch.column(prob_idx);
2080        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
2081        for i in 0..vids.len() {
2082            if !vids.is_null(i) {
2083                let p = probs
2084                    .and_then(|arr| {
2085                        if arr.is_null(i) {
2086                            None
2087                        } else {
2088                            Some(arr.value(i))
2089                        }
2090                    })
2091                    .unwrap_or(0.0);
2092                // If multiple facts for same VID, use noisy-OR combination:
2093                // combined = 1 - (1 - existing) * (1 - new)
2094                prob_map
2095                    .entry(vids.value(i))
2096                    .and_modify(|existing| {
2097                        *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
2098                    })
2099                    .or_insert(p);
2100            }
2101        }
2102    }
2103
2104    // Add complement column to each batch
2105    let mut result = Vec::new();
2106    for batch in batches {
2107        let Ok(idx) = batch.schema().index_of(left_col) else {
2108            result.push(batch);
2109            continue;
2110        };
2111        let Some(vids) = batch.column(idx).as_any().downcast_ref::<UInt64Array>() else {
2112            result.push(batch);
2113            continue;
2114        };
2115
2116        // Compute complement values: 1 - p for matched VIDs, 1.0 for absent
2117        let complements: Vec<f64> = (0..vids.len())
2118            .map(|i| {
2119                if vids.is_null(i) {
2120                    1.0
2121                } else {
2122                    let p = prob_map.get(&vids.value(i)).copied().unwrap_or(0.0);
2123                    1.0 - p
2124                }
2125            })
2126            .collect();
2127
2128        let complement_arr = Float64Array::from(complements);
2129
2130        // Add the complement column to the batch
2131        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
2132        columns.push(std::sync::Arc::new(complement_arr));
2133
2134        let mut fields: Vec<std::sync::Arc<arrow_schema::Field>> =
2135            batch.schema().fields().iter().cloned().collect();
2136        fields.push(std::sync::Arc::new(arrow_schema::Field::new(
2137            complement_col_name,
2138            arrow_schema::DataType::Float64,
2139            true,
2140        )));
2141
2142        let new_schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
2143        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
2144        result.push(new_batch);
2145    }
2146    Ok(result)
2147}
2148
2149/// Probabilistic complement for composite (multi-column) join keys.
2150///
2151/// Builds a composite key from all `join_cols` right-side columns in
2152/// `neg_facts`, maps each composite key to a probability via noisy-OR
2153/// combination, then adds a single `complement_col_name` column with
2154/// `1 - p` for matched keys and `1.0` for absent keys.
2155pub fn apply_prob_complement_composite(
2156    batches: Vec<RecordBatch>,
2157    neg_facts: &[RecordBatch],
2158    join_cols: &[(String, String)],
2159    prob_col: &str,
2160    complement_col_name: &str,
2161) -> datafusion::error::Result<Vec<RecordBatch>> {
2162    use arrow_array::{Array as _, Float64Array, UInt64Array};
2163
2164    // Build composite-key → probability lookup from negative facts.
2165    let mut prob_map: HashMap<Vec<u64>, f64> = HashMap::new();
2166    for batch in neg_facts {
2167        let right_indices: Vec<usize> = join_cols
2168            .iter()
2169            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
2170            .collect();
2171        if right_indices.len() != join_cols.len() {
2172            continue;
2173        }
2174        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
2175            continue;
2176        };
2177        let prob_arr = batch.column(prob_idx);
2178        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
2179        for row in 0..batch.num_rows() {
2180            let mut key = Vec::with_capacity(right_indices.len());
2181            let mut valid = true;
2182            for &ci in &right_indices {
2183                let col = batch.column(ci);
2184                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
2185                    if vids.is_null(row) {
2186                        valid = false;
2187                        break;
2188                    }
2189                    key.push(vids.value(row));
2190                } else {
2191                    valid = false;
2192                    break;
2193                }
2194            }
2195            if !valid {
2196                continue;
2197            }
2198            let p = probs
2199                .and_then(|arr| {
2200                    if arr.is_null(row) {
2201                        None
2202                    } else {
2203                        Some(arr.value(row))
2204                    }
2205                })
2206                .unwrap_or(0.0);
2207            // Noisy-OR combination for duplicate composite keys.
2208            prob_map
2209                .entry(key)
2210                .and_modify(|existing| {
2211                    *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
2212                })
2213                .or_insert(p);
2214        }
2215    }
2216
2217    // Add complement column to each batch.
2218    let mut result = Vec::new();
2219    for batch in batches {
2220        let left_indices: Vec<usize> = join_cols
2221            .iter()
2222            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
2223            .collect();
2224        if left_indices.len() != join_cols.len() {
2225            result.push(batch);
2226            continue;
2227        }
2228        let all_u64 = left_indices.iter().all(|&ci| {
2229            batch
2230                .column(ci)
2231                .as_any()
2232                .downcast_ref::<UInt64Array>()
2233                .is_some()
2234        });
2235        if !all_u64 {
2236            result.push(batch);
2237            continue;
2238        }
2239
2240        let complements: Vec<f64> = (0..batch.num_rows())
2241            .map(|row| {
2242                let mut key = Vec::with_capacity(left_indices.len());
2243                for &ci in &left_indices {
2244                    let vids = batch
2245                        .column(ci)
2246                        .as_any()
2247                        .downcast_ref::<UInt64Array>()
2248                        .unwrap();
2249                    if vids.is_null(row) {
2250                        return 1.0;
2251                    }
2252                    key.push(vids.value(row));
2253                }
2254                let p = prob_map.get(&key).copied().unwrap_or(0.0);
2255                1.0 - p
2256            })
2257            .collect();
2258
2259        let complement_arr = Float64Array::from(complements);
2260        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
2261        columns.push(Arc::new(complement_arr));
2262
2263        let mut fields: Vec<Arc<arrow_schema::Field>> =
2264            batch.schema().fields().iter().cloned().collect();
2265        fields.push(Arc::new(arrow_schema::Field::new(
2266            complement_col_name,
2267            arrow_schema::DataType::Float64,
2268            true,
2269        )));
2270
2271        let new_schema = Arc::new(arrow_schema::Schema::new(fields));
2272        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
2273        result.push(new_batch);
2274    }
2275    Ok(result)
2276}
2277
2278/// Boolean anti-join for composite (multi-column) join keys.
2279///
2280/// Builds a `HashSet<Vec<u64>>` from `neg_facts` using all right-side
2281/// columns in `join_cols`, then filters `batches` to keep only rows
2282/// whose composite left-side key is NOT in the set.
2283pub fn apply_anti_join_composite(
2284    batches: Vec<RecordBatch>,
2285    neg_facts: &[RecordBatch],
2286    join_cols: &[(String, String)],
2287) -> datafusion::error::Result<Vec<RecordBatch>> {
2288    use arrow::compute::filter_record_batch;
2289    use arrow_array::{Array as _, BooleanArray, UInt64Array};
2290
2291    // Collect composite keys from the negated rule's derived facts.
2292    let mut banned: HashSet<Vec<u64>> = HashSet::new();
2293    for batch in neg_facts {
2294        let right_indices: Vec<usize> = join_cols
2295            .iter()
2296            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
2297            .collect();
2298        if right_indices.len() != join_cols.len() {
2299            continue;
2300        }
2301        for row in 0..batch.num_rows() {
2302            let mut key = Vec::with_capacity(right_indices.len());
2303            let mut valid = true;
2304            for &ci in &right_indices {
2305                let col = batch.column(ci);
2306                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
2307                    if vids.is_null(row) {
2308                        valid = false;
2309                        break;
2310                    }
2311                    key.push(vids.value(row));
2312                } else {
2313                    valid = false;
2314                    break;
2315                }
2316            }
2317            if valid {
2318                banned.insert(key);
2319            }
2320        }
2321    }
2322
2323    if banned.is_empty() {
2324        return Ok(batches);
2325    }
2326
2327    // Filter body batches: keep rows where composite left key NOT IN banned.
2328    let mut result = Vec::new();
2329    for batch in batches {
2330        let left_indices: Vec<usize> = join_cols
2331            .iter()
2332            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
2333            .collect();
2334        if left_indices.len() != join_cols.len() {
2335            result.push(batch);
2336            continue;
2337        }
2338        let all_u64 = left_indices.iter().all(|&ci| {
2339            batch
2340                .column(ci)
2341                .as_any()
2342                .downcast_ref::<UInt64Array>()
2343                .is_some()
2344        });
2345        if !all_u64 {
2346            result.push(batch);
2347            continue;
2348        }
2349
2350        let keep: Vec<bool> = (0..batch.num_rows())
2351            .map(|row| {
2352                let mut key = Vec::with_capacity(left_indices.len());
2353                for &ci in &left_indices {
2354                    let vids = batch
2355                        .column(ci)
2356                        .as_any()
2357                        .downcast_ref::<UInt64Array>()
2358                        .unwrap();
2359                    if vids.is_null(row) {
2360                        return true; // null keys are never banned
2361                    }
2362                    key.push(vids.value(row));
2363                }
2364                !banned.contains(&key)
2365            })
2366            .collect();
2367        let keep_arr = BooleanArray::from(keep);
2368        let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
2369        if filtered.num_rows() > 0 {
2370            result.push(filtered);
2371        }
2372    }
2373    Ok(result)
2374}
2375
2376/// Multiply `__prob_complement_*` columns into the rule's PROB column and clean up.
2377///
2378/// After IS NOT probabilistic complement semantics have added `__prob_complement_*`
2379/// columns to clause results, this function:
2380/// 1. Computes the product of all complement factor columns
2381/// 2. Multiplies the product into the existing PROB column (if any)
2382/// 3. Removes the internal `__prob_complement_*` columns from the output
2383///
2384/// If the rule has no PROB column, complement columns are simply removed
2385/// (the complement information is discarded and IS NOT acts as a keep-all).
2386pub fn multiply_prob_factors(
2387    batches: Vec<RecordBatch>,
2388    prob_col: Option<&str>,
2389    complement_cols: &[String],
2390) -> datafusion::error::Result<Vec<RecordBatch>> {
2391    use arrow_array::{Array as _, Float64Array};
2392
2393    let mut result = Vec::with_capacity(batches.len());
2394
2395    for batch in batches {
2396        if batch.num_rows() == 0 {
2397            // Remove complement columns from empty batches
2398            let keep: Vec<usize> = batch
2399                .schema()
2400                .fields()
2401                .iter()
2402                .enumerate()
2403                .filter(|(_, f)| !complement_cols.contains(f.name()))
2404                .map(|(i, _)| i)
2405                .collect();
2406            let fields: Vec<_> = keep
2407                .iter()
2408                .map(|&i| batch.schema().field(i).clone())
2409                .collect();
2410            let cols: Vec<_> = keep.iter().map(|&i| batch.column(i).clone()).collect();
2411            let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
2412            result.push(
2413                RecordBatch::try_new(schema, cols).map_err(|e| {
2414                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
2415                })?,
2416            );
2417            continue;
2418        }
2419
2420        let num_rows = batch.num_rows();
2421
2422        // 1. Compute product of all complement factors
2423        let mut combined = vec![1.0f64; num_rows];
2424        for col_name in complement_cols {
2425            if let Ok(idx) = batch.schema().index_of(col_name) {
2426                let arr = batch
2427                    .column(idx)
2428                    .as_any()
2429                    .downcast_ref::<Float64Array>()
2430                    .ok_or_else(|| {
2431                        datafusion::error::DataFusionError::Internal(format!(
2432                            "Expected Float64 for complement column {col_name}"
2433                        ))
2434                    })?;
2435                for (i, val) in combined.iter_mut().enumerate().take(num_rows) {
2436                    if !arr.is_null(i) {
2437                        *val *= arr.value(i);
2438                    }
2439                }
2440            }
2441        }
2442
2443        // 2. If there's a PROB column, multiply combined into it
2444        let final_prob: Vec<f64> = if let Some(prob_name) = prob_col {
2445            if let Ok(idx) = batch.schema().index_of(prob_name) {
2446                let arr = batch
2447                    .column(idx)
2448                    .as_any()
2449                    .downcast_ref::<Float64Array>()
2450                    .ok_or_else(|| {
2451                        datafusion::error::DataFusionError::Internal(format!(
2452                            "Expected Float64 for PROB column {prob_name}"
2453                        ))
2454                    })?;
2455                (0..num_rows)
2456                    .map(|i| {
2457                        if arr.is_null(i) {
2458                            combined[i]
2459                        } else {
2460                            arr.value(i) * combined[i]
2461                        }
2462                    })
2463                    .collect()
2464            } else {
2465                combined
2466            }
2467        } else {
2468            combined
2469        };
2470
2471        let new_prob_array: arrow_array::ArrayRef =
2472            std::sync::Arc::new(Float64Array::from(final_prob));
2473
2474        // 3. Build output: replace PROB column, remove complement columns
2475        let mut fields = Vec::new();
2476        let mut columns = Vec::new();
2477
2478        for (idx, field) in batch.schema().fields().iter().enumerate() {
2479            if complement_cols.contains(field.name()) {
2480                continue;
2481            }
2482            if prob_col.is_some_and(|p| field.name() == p) {
2483                fields.push(field.clone());
2484                columns.push(new_prob_array.clone());
2485            } else {
2486                fields.push(field.clone());
2487                columns.push(batch.column(idx).clone());
2488            }
2489        }
2490
2491        let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
2492        result.push(RecordBatch::try_new(schema, columns).map_err(arrow_err)?);
2493    }
2494
2495    Ok(result)
2496}
2497
2498/// Update derived scan handles before evaluating a rule's clause bodies.
2499///
2500/// For self-references: inject delta (semi-naive optimization).
2501/// For cross-references: inject full facts.
2502fn update_derived_scan_handles(
2503    registry: &DerivedScanRegistry,
2504    states: &[FixpointState],
2505    current_rule_idx: usize,
2506    rules: &[FixpointRulePlan],
2507) {
2508    let current_rule_name = &rules[current_rule_idx].name;
2509
2510    for entry in &registry.entries {
2511        // Find the state for this entry's rule
2512        let source_state_idx = rules.iter().position(|r| r.name == entry.rule_name);
2513        let Some(source_idx) = source_state_idx else {
2514            continue;
2515        };
2516
2517        let is_self = entry.rule_name == *current_rule_name;
2518        let data = if is_self {
2519            // Self-ref: inject delta for semi-naive
2520            states[source_idx].all_delta().to_vec()
2521        } else {
2522            // Cross-ref: inject full facts
2523            states[source_idx].all_facts().to_vec()
2524        };
2525
2526        // If empty, write an empty batch so the scan returns zero rows
2527        let data = if data.is_empty() || data.iter().all(|b| b.num_rows() == 0) {
2528            vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
2529        } else {
2530            data
2531        };
2532
2533        let mut guard = entry.data.write();
2534        *guard = data;
2535    }
2536}
2537
2538// ---------------------------------------------------------------------------
2539// DerivedScanExec — physical plan that reads from shared data at execution time
2540// ---------------------------------------------------------------------------
2541
2542/// Physical plan for `LocyDerivedScan` that reads from a shared `Arc<RwLock>` at
2543/// execution time (not at plan creation time).
2544///
2545/// This is critical for fixpoint iteration: the data handle is updated between
2546/// iterations, and each re-execution of the subplan must read the latest data.
pub struct DerivedScanExec {
    /// Shared data handle; rewritten between fixpoint iterations and read
    /// fresh on every `execute` call.
    data: Arc<RwLock<Vec<RecordBatch>>>,
    /// Schema of the derived relation; also used for empty placeholder batches.
    schema: SchemaRef,
    /// Plan properties pre-computed from `schema`.
    properties: PlanProperties,
}
2552
2553impl DerivedScanExec {
2554    pub fn new(data: Arc<RwLock<Vec<RecordBatch>>>, schema: SchemaRef) -> Self {
2555        let properties = compute_plan_properties(Arc::clone(&schema));
2556        Self {
2557            data,
2558            schema,
2559            properties,
2560        }
2561    }
2562}
2563
2564impl fmt::Debug for DerivedScanExec {
2565    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2566        f.debug_struct("DerivedScanExec")
2567            .field("schema", &self.schema)
2568            .finish()
2569    }
2570}
2571
2572impl DisplayAs for DerivedScanExec {
2573    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2574        write!(f, "DerivedScanExec")
2575    }
2576}
2577
2578impl ExecutionPlan for DerivedScanExec {
2579    fn name(&self) -> &str {
2580        "DerivedScanExec"
2581    }
2582    fn as_any(&self) -> &dyn Any {
2583        self
2584    }
2585    fn schema(&self) -> SchemaRef {
2586        Arc::clone(&self.schema)
2587    }
2588    fn properties(&self) -> &PlanProperties {
2589        &self.properties
2590    }
2591    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
2592        vec![]
2593    }
2594    fn with_new_children(
2595        self: Arc<Self>,
2596        _children: Vec<Arc<dyn ExecutionPlan>>,
2597    ) -> DFResult<Arc<dyn ExecutionPlan>> {
2598        Ok(self)
2599    }
2600    fn execute(
2601        &self,
2602        _partition: usize,
2603        _context: Arc<TaskContext>,
2604    ) -> DFResult<SendableRecordBatchStream> {
2605        let batches = {
2606            let guard = self.data.read();
2607            if guard.is_empty() {
2608                vec![RecordBatch::new_empty(Arc::clone(&self.schema))]
2609            } else {
2610                guard.clone()
2611            }
2612        };
2613        Ok(Box::pin(MemoryStream::try_new(
2614            batches,
2615            Arc::clone(&self.schema),
2616            None,
2617        )?))
2618    }
2619}
2620
2621// ---------------------------------------------------------------------------
2622// InMemoryExec — wrapper to feed Vec<RecordBatch> into operator chains
2623// ---------------------------------------------------------------------------
2624
2625/// Simple in-memory execution plan that serves pre-computed batches.
2626///
2627/// Used internally to feed fixpoint results into post-fixpoint operator chains
2628/// (FOLD, BEST BY). Not exported — only used within this module.
struct InMemoryExec {
    /// Pre-computed batches served verbatim on every execution.
    batches: Vec<RecordBatch>,
    /// Schema reported for `batches`.
    schema: SchemaRef,
    /// Plan properties pre-computed from `schema`.
    properties: PlanProperties,
}
2634
2635impl InMemoryExec {
2636    fn new(batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
2637        let properties = compute_plan_properties(Arc::clone(&schema));
2638        Self {
2639            batches,
2640            schema,
2641            properties,
2642        }
2643    }
2644}
2645
2646impl fmt::Debug for InMemoryExec {
2647    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2648        f.debug_struct("InMemoryExec")
2649            .field("num_batches", &self.batches.len())
2650            .field("schema", &self.schema)
2651            .finish()
2652    }
2653}
2654
2655impl DisplayAs for InMemoryExec {
2656    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2657        write!(f, "InMemoryExec: batches={}", self.batches.len())
2658    }
2659}
2660
2661impl ExecutionPlan for InMemoryExec {
2662    fn name(&self) -> &str {
2663        "InMemoryExec"
2664    }
2665    fn as_any(&self) -> &dyn Any {
2666        self
2667    }
2668    fn schema(&self) -> SchemaRef {
2669        Arc::clone(&self.schema)
2670    }
2671    fn properties(&self) -> &PlanProperties {
2672        &self.properties
2673    }
2674    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
2675        vec![]
2676    }
2677    fn with_new_children(
2678        self: Arc<Self>,
2679        _children: Vec<Arc<dyn ExecutionPlan>>,
2680    ) -> DFResult<Arc<dyn ExecutionPlan>> {
2681        Ok(self)
2682    }
2683    fn execute(
2684        &self,
2685        _partition: usize,
2686        _context: Arc<TaskContext>,
2687    ) -> DFResult<SendableRecordBatchStream> {
2688        Ok(Box::pin(MemoryStream::try_new(
2689            self.batches.clone(),
2690            Arc::clone(&self.schema),
2691            None,
2692        )?))
2693    }
2694}
2695
2696// ---------------------------------------------------------------------------
2697// Post-fixpoint chain — FOLD and BEST BY on converged facts
2698// ---------------------------------------------------------------------------
2699
2700/// Apply post-fixpoint operators (FOLD, BEST BY, PRIORITY) to converged facts.
2701pub(crate) async fn apply_post_fixpoint_chain(
2702    facts: Vec<RecordBatch>,
2703    rule: &FixpointRulePlan,
2704    task_ctx: &Arc<TaskContext>,
2705    strict_probability_domain: bool,
2706    probability_epsilon: f64,
2707) -> DFResult<Vec<RecordBatch>> {
2708    if !rule.has_fold && !rule.has_best_by && !rule.has_priority {
2709        return Ok(facts);
2710    }
2711
2712    // Wrap facts in InMemoryExec.
2713    // Prefer the actual batch schema (from physical execution) over the
2714    // pre-computed yield_schema, which may have wrong inferred types
2715    // (e.g. Float64 for a string property).
2716    let schema = facts
2717        .iter()
2718        .find(|b| b.num_rows() > 0)
2719        .map(|b| b.schema())
2720        .unwrap_or_else(|| Arc::clone(&rule.yield_schema));
2721    let input: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(facts, schema));
2722
2723    // Apply PRIORITY first — keeps only rows with max __priority per KEY group,
2724    // then strips the __priority column from output.
2725    // Must run before FOLD so that the __priority column is still present.
2726    let current: Arc<dyn ExecutionPlan> = if rule.has_priority {
2727        let priority_schema = input.schema();
2728        let priority_idx = priority_schema.index_of("__priority").map_err(|_| {
2729            datafusion::common::DataFusionError::Internal(
2730                "PRIORITY rule missing __priority column".to_string(),
2731            )
2732        })?;
2733        Arc::new(PriorityExec::new(
2734            input,
2735            rule.key_column_indices.clone(),
2736            priority_idx,
2737        ))
2738    } else {
2739        input
2740    };
2741
2742    // Apply FOLD
2743    let current: Arc<dyn ExecutionPlan> = if rule.has_fold && !rule.fold_bindings.is_empty() {
2744        Arc::new(FoldExec::new(
2745            current,
2746            rule.key_column_indices.clone(),
2747            rule.fold_bindings.clone(),
2748            strict_probability_domain,
2749            probability_epsilon,
2750        ))
2751    } else {
2752        current
2753    };
2754
2755    // Apply BEST BY
2756    let current: Arc<dyn ExecutionPlan> = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
2757        Arc::new(BestByExec::new(
2758            current,
2759            rule.key_column_indices.clone(),
2760            rule.best_by_criteria.clone(),
2761            rule.deterministic,
2762        ))
2763    } else {
2764        current
2765    };
2766
2767    collect_all_partitions(&current, Arc::clone(task_ctx)).await
2768}
2769
2770// ---------------------------------------------------------------------------
2771// FixpointExec — DataFusion ExecutionPlan
2772// ---------------------------------------------------------------------------
2773
2774/// DataFusion `ExecutionPlan` that drives semi-naive fixpoint iteration.
2775///
2776/// Has no physical children: clause bodies are re-planned from logical plans
2777/// on each iteration (same pattern as `RecursiveCTEExec` and `GraphApplyExec`).
pub struct FixpointExec {
    /// Rule plans for the recursive stratum; cloned into the async loop on execute.
    rules: Vec<FixpointRulePlan>,
    /// Iteration cap forwarded to `run_fixpoint_loop`.
    max_iterations: usize,
    /// Wall-clock budget forwarded to `run_fixpoint_loop`.
    timeout: Duration,
    /// Graph execution context forwarded to `run_fixpoint_loop`.
    graph_ctx: Arc<GraphExecutionContext>,
    /// DataFusion session forwarded to `run_fixpoint_loop`; clause bodies are
    /// re-planned there each iteration.
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    /// Storage manager forwarded to `run_fixpoint_loop`.
    storage: Arc<StorageManager>,
    /// Uni schema metadata forwarded to `run_fixpoint_loop`.
    schema_info: Arc<UniSchema>,
    /// Query parameters forwarded to `run_fixpoint_loop`.
    params: HashMap<String, Value>,
    /// Registry of derived-scan handles updated between iterations.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    /// Schema of the batches this operator emits.
    output_schema: SchemaRef,
    /// Plan properties pre-computed from `output_schema`.
    properties: PlanProperties,
    /// Metrics set; a `BaselineMetrics` is created per executed partition.
    metrics: ExecutionPlanMetricsSet,
    /// Byte budget for derived facts, forwarded to `run_fixpoint_loop`.
    max_derived_bytes: usize,
    /// Optional provenance tracker populated during fixpoint iteration.
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    /// Shared slot written with per-rule iteration counts after convergence.
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    /// Probability-domain strictness flag, forwarded to `run_fixpoint_loop`.
    strict_probability_domain: bool,
    /// Probability comparison epsilon, forwarded to `run_fixpoint_loop`.
    probability_epsilon: f64,
    /// Exact (BDD-based) probability toggle, forwarded to `run_fixpoint_loop`.
    exact_probability: bool,
    /// BDD variable cap, forwarded to `run_fixpoint_loop`.
    max_bdd_variables: usize,
    /// Shared slot for runtime warnings collected during fixpoint iteration.
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    /// Shared slot for groups where BDD fell back to independence mode.
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    /// When > 0, retain at most this many proofs per fact (top-k provenance).
    top_k_proofs: usize,
}
2807
2808impl fmt::Debug for FixpointExec {
2809    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2810        f.debug_struct("FixpointExec")
2811            .field("rules_count", &self.rules.len())
2812            .field("max_iterations", &self.max_iterations)
2813            .field("timeout", &self.timeout)
2814            .field("output_schema", &self.output_schema)
2815            .field("max_derived_bytes", &self.max_derived_bytes)
2816            .finish_non_exhaustive()
2817    }
2818}
2819
2820impl FixpointExec {
2821    /// Create a new `FixpointExec`.
2822    #[expect(
2823        clippy::too_many_arguments,
2824        reason = "FixpointExec configuration needs all context"
2825    )]
2826    pub fn new(
2827        rules: Vec<FixpointRulePlan>,
2828        max_iterations: usize,
2829        timeout: Duration,
2830        graph_ctx: Arc<GraphExecutionContext>,
2831        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
2832        storage: Arc<StorageManager>,
2833        schema_info: Arc<UniSchema>,
2834        params: HashMap<String, Value>,
2835        derived_scan_registry: Arc<DerivedScanRegistry>,
2836        output_schema: SchemaRef,
2837        max_derived_bytes: usize,
2838        derivation_tracker: Option<Arc<ProvenanceStore>>,
2839        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
2840        strict_probability_domain: bool,
2841        probability_epsilon: f64,
2842        exact_probability: bool,
2843        max_bdd_variables: usize,
2844        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
2845        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
2846        top_k_proofs: usize,
2847    ) -> Self {
2848        let properties = compute_plan_properties(Arc::clone(&output_schema));
2849        Self {
2850            rules,
2851            max_iterations,
2852            timeout,
2853            graph_ctx,
2854            session_ctx,
2855            storage,
2856            schema_info,
2857            params,
2858            derived_scan_registry,
2859            output_schema,
2860            properties,
2861            metrics: ExecutionPlanMetricsSet::new(),
2862            max_derived_bytes,
2863            derivation_tracker,
2864            iteration_counts,
2865            strict_probability_domain,
2866            probability_epsilon,
2867            exact_probability,
2868            max_bdd_variables,
2869            warnings_slot,
2870            approximate_slot,
2871            top_k_proofs,
2872        }
2873    }
2874
2875    /// Returns the shared iteration counts slot for post-execution inspection.
2876    pub fn iteration_counts(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
2877        Arc::clone(&self.iteration_counts)
2878    }
2879}
2880
2881impl DisplayAs for FixpointExec {
2882    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2883        write!(
2884            f,
2885            "FixpointExec: rules=[{}], max_iter={}, timeout={:?}",
2886            self.rules
2887                .iter()
2888                .map(|r| r.name.as_str())
2889                .collect::<Vec<_>>()
2890                .join(", "),
2891            self.max_iterations,
2892            self.timeout,
2893        )
2894    }
2895}
2896
impl ExecutionPlan for FixpointExec {
    fn name(&self) -> &str {
        "FixpointExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.output_schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        // No physical children — clause bodies are re-planned from their
        // logical plans on every fixpoint iteration.
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Plan(
                "FixpointExec has no children".to_string(),
            ));
        }
        Ok(self)
    }

    /// Kick off the fixpoint loop as a future wrapped in a `FixpointStream`.
    ///
    /// Everything the loop needs is cloned out of `self` up front so the
    /// `async move` block owns its inputs and the stream is `'static`.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let metrics = BaselineMetrics::new(&self.metrics, partition);

        // Deep-copy each rule plan field by field: FixpointRulePlan has no
        // Clone derive, but every field — including the LogicalPlan enum —
        // is itself cloneable, so a manual struct-literal copy works.
        let rules = self
            .rules
            .iter()
            .map(|r| {
                FixpointRulePlan {
                    name: r.name.clone(),
                    clauses: r
                        .clauses
                        .iter()
                        .map(|c| FixpointClausePlan {
                            body_logical: c.body_logical.clone(),
                            is_ref_bindings: c.is_ref_bindings.clone(),
                            priority: c.priority,
                            along_bindings: c.along_bindings.clone(),
                        })
                        .collect(),
                    yield_schema: Arc::clone(&r.yield_schema),
                    key_column_indices: r.key_column_indices.clone(),
                    priority: r.priority,
                    has_fold: r.has_fold,
                    fold_bindings: r.fold_bindings.clone(),
                    has_best_by: r.has_best_by,
                    best_by_criteria: r.best_by_criteria.clone(),
                    has_priority: r.has_priority,
                    deterministic: r.deterministic,
                    prob_column_name: r.prob_column_name.clone(),
                }
            })
            .collect();

        // Capture the remaining configuration and shared slots for the loop.
        let max_iterations = self.max_iterations;
        let timeout = self.timeout;
        let graph_ctx = Arc::clone(&self.graph_ctx);
        let session_ctx = Arc::clone(&self.session_ctx);
        let storage = Arc::clone(&self.storage);
        let schema_info = Arc::clone(&self.schema_info);
        let params = self.params.clone();
        let registry = Arc::clone(&self.derived_scan_registry);
        let output_schema = Arc::clone(&self.output_schema);
        let max_derived_bytes = self.max_derived_bytes;
        let derivation_tracker = self.derivation_tracker.clone();
        let iteration_counts = Arc::clone(&self.iteration_counts);
        let strict_probability_domain = self.strict_probability_domain;
        let probability_epsilon = self.probability_epsilon;
        let exact_probability = self.exact_probability;
        let max_bdd_variables = self.max_bdd_variables;
        let warnings_slot = Arc::clone(&self.warnings_slot);
        let approximate_slot = Arc::clone(&self.approximate_slot);
        let top_k_proofs = self.top_k_proofs;

        let fut = async move {
            run_fixpoint_loop(
                rules,
                max_iterations,
                timeout,
                graph_ctx,
                session_ctx,
                storage,
                schema_info,
                params,
                registry,
                output_schema,
                max_derived_bytes,
                derivation_tracker,
                iteration_counts,
                strict_probability_domain,
                probability_epsilon,
                exact_probability,
                max_bdd_variables,
                warnings_slot,
                approximate_slot,
                top_k_proofs,
            )
            .await
        };

        // The stream starts in Running and flips to Emitting on completion.
        Ok(Box::pin(FixpointStream {
            state: FixpointStreamState::Running(Box::pin(fut)),
            schema: Arc::clone(&self.output_schema),
            metrics,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
3029
3030// ---------------------------------------------------------------------------
3031// FixpointStream — async state machine for streaming results
3032// ---------------------------------------------------------------------------
3033
/// State machine for `FixpointStream`: run the loop to completion, then
/// drain the accumulated batches one at a time.
enum FixpointStreamState {
    /// Fixpoint loop is running; resolves to all converged result batches.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// Emitting accumulated result batches one at a time; `usize` is the
    /// index of the next batch to emit.
    Emitting(Vec<RecordBatch>, usize),
    /// All batches emitted.
    Done,
}
3042
/// Record-batch stream that drives `FixpointStreamState` in `poll_next`.
struct FixpointStream {
    /// Current phase: running the loop, emitting batches, or done.
    state: FixpointStreamState,
    /// Output schema reported to DataFusion.
    schema: SchemaRef,
    /// Baseline metrics; output row counts are recorded per emitted batch.
    metrics: BaselineMetrics,
}
3048
3049impl Stream for FixpointStream {
3050    type Item = DFResult<RecordBatch>;
3051
3052    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
3053        let this = self.get_mut();
3054        loop {
3055            match &mut this.state {
3056                FixpointStreamState::Running(fut) => match fut.as_mut().poll(cx) {
3057                    Poll::Ready(Ok(batches)) => {
3058                        if batches.is_empty() {
3059                            this.state = FixpointStreamState::Done;
3060                            return Poll::Ready(None);
3061                        }
3062                        this.state = FixpointStreamState::Emitting(batches, 0);
3063                        // Loop to emit first batch
3064                    }
3065                    Poll::Ready(Err(e)) => {
3066                        this.state = FixpointStreamState::Done;
3067                        return Poll::Ready(Some(Err(e)));
3068                    }
3069                    Poll::Pending => return Poll::Pending,
3070                },
3071                FixpointStreamState::Emitting(batches, idx) => {
3072                    if *idx >= batches.len() {
3073                        this.state = FixpointStreamState::Done;
3074                        return Poll::Ready(None);
3075                    }
3076                    let batch = batches[*idx].clone();
3077                    *idx += 1;
3078                    this.metrics.record_output(batch.num_rows());
3079                    return Poll::Ready(Some(Ok(batch)));
3080                }
3081                FixpointStreamState::Done => return Poll::Ready(None),
3082            }
3083        }
3084    }
3085}
3086
3087impl RecordBatchStream for FixpointStream {
3088    fn schema(&self) -> SchemaRef {
3089        Arc::clone(&self.schema)
3090    }
3091}
3092
3093// ---------------------------------------------------------------------------
3094// Unit tests
3095// ---------------------------------------------------------------------------
3096
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::df_graph::locy_fold::FoldAggKind;
    use arrow_array::{ArrayRef, Float64Array, Int64Array, StringArray, UInt64Array};
    use arrow_schema::{DataType, Field, Schema};

    /// Two-column `(name: Utf8, value: Int64)` schema shared by most tests.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Int64, true),
        ]))
    }

    /// Builds a `(name, value)` batch over `test_schema()`; panics on length mismatch.
    fn make_batch(names: &[&str], values: &[i64]) -> RecordBatch {
        RecordBatch::try_new(
            test_schema(),
            vec![
                Arc::new(StringArray::from(
                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
                )),
                Arc::new(Int64Array::from(values.to_vec())),
            ],
        )
        .unwrap()
    }

    // --- FixpointState dedup tests ---

    #[tokio::test]
    async fn test_fixpoint_state_empty_facts_adds_all() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
        let changed = state.merge_delta(vec![batch], None).await.unwrap();

        assert!(changed);
        assert_eq!(state.all_facts().len(), 1);
        assert_eq!(state.all_facts()[0].num_rows(), 3);
        assert_eq!(state.all_delta().len(), 1);
        assert_eq!(state.all_delta()[0].num_rows(), 3);
    }

    #[tokio::test]
    async fn test_fixpoint_state_exact_duplicates_excluded() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        state.merge_delta(vec![batch1], None).await.unwrap();

        // Same rows again
        let batch2 = make_batch(&["a", "b"], &[1, 2]);
        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
        assert!(!changed);
        assert!(
            state.all_delta().is_empty() || state.all_delta().iter().all(|b| b.num_rows() == 0)
        );
    }

    #[tokio::test]
    async fn test_fixpoint_state_partial_overlap() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        state.merge_delta(vec![batch1], None).await.unwrap();

        // "a":1 is duplicate, "c":3 is new
        let batch2 = make_batch(&["a", "c"], &[1, 3]);
        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
        assert!(changed);

        // Delta should have only "c":3
        let delta_rows: usize = state.all_delta().iter().map(|b| b.num_rows()).sum();
        assert_eq!(delta_rows, 1);

        // Total facts: a:1, b:2, c:3
        let total_rows: usize = state.all_facts().iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 3);
    }

    #[tokio::test]
    async fn test_fixpoint_state_convergence() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch = make_batch(&["a"], &[1]);
        state.merge_delta(vec![batch], None).await.unwrap();

        // Empty candidates → converged
        let changed = state.merge_delta(vec![], None).await.unwrap();
        assert!(!changed);
        assert!(state.is_converged());
    }

    // --- RowDedupState tests ---

    #[test]
    fn test_row_dedup_persistent_across_calls() {
        // RowDedupState should remember rows from the first call so the second
        // call does not re-accept them (O(M) per iteration, no facts re-scan).
        let schema = test_schema();
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        let delta1 = rd.compute_delta(&[batch1], &schema).unwrap();
        // First call: both rows are new.
        let rows1: usize = delta1.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows1, 2);

        // Second call with same rows: seen set already has them → empty delta.
        let batch2 = make_batch(&["a", "b"], &[1, 2]);
        let delta2 = rd.compute_delta(&[batch2], &schema).unwrap();
        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows2, 0);

        // Third call with one old + one new: only the new row is returned.
        let batch3 = make_batch(&["a", "c"], &[1, 3]);
        let delta3 = rd.compute_delta(&[batch3], &schema).unwrap();
        let rows3: usize = delta3.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows3, 1);
    }

    #[test]
    fn test_row_dedup_null_handling() {
        let schema: SchemaRef = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, true),
            Field::new("b", DataType::Int64, true),
        ]));
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");

        // Two rows: (NULL, 1) and (NULL, 1) — same NULLs → duplicate.
        let batch_nulls = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![None::<&str>, None::<&str>])),
                Arc::new(Int64Array::from(vec![1i64, 1i64])),
            ],
        )
        .unwrap();
        let delta = rd.compute_delta(&[batch_nulls], &schema).unwrap();
        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows, 1, "two identical NULL rows should be deduped to one");

        // (NULL, 2) — NULL in same col but different non-null col → distinct.
        let batch_diff = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(Int64Array::from(vec![2i64])),
            ],
        )
        .unwrap();
        let delta2 = rd.compute_delta(&[batch_diff], &schema).unwrap();
        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows2, 1, "(NULL, 2) is distinct from (NULL, 1)");
    }

    #[test]
    fn test_row_dedup_within_candidate_dedup() {
        // Duplicate rows within a single candidate batch should be collapsed to one.
        let schema = test_schema();
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");

        // Batch with three rows: a:1, a:1, b:2 — "a:1" appears twice.
        let batch = make_batch(&["a", "a", "b"], &[1, 1, 2]);
        let delta = rd.compute_delta(&[batch], &schema).unwrap();
        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows, 2, "within-batch dup should be collapsed: a:1, b:2");
    }

    // --- Float rounding tests ---

    #[test]
    fn test_round_float_columns_near_duplicates() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("dist", DataType::Float64, true),
        ]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![Some("a"), Some("a")])),
                Arc::new(Float64Array::from(vec![1.0000000000001, 1.0000000000002])),
            ],
        )
        .unwrap();

        let rounded = round_float_columns(&[batch]);
        assert_eq!(rounded.len(), 1);
        let col = rounded[0]
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        // Both should round to same value
        assert_eq!(col.value(0), col.value(1));
    }

    // --- DerivedScanRegistry tests ---

    #[test]
    fn test_registry_write_read_round_trip() {
        let schema = test_schema();
        let data = Arc::new(RwLock::new(Vec::new()));
        let mut reg = DerivedScanRegistry::new();
        reg.add(DerivedScanEntry {
            scan_index: 0,
            rule_name: "reachable".into(),
            is_self_ref: true,
            data: Arc::clone(&data),
            schema: Arc::clone(&schema),
        });

        let batch = make_batch(&["x"], &[42]);
        reg.write_data(0, vec![batch.clone()]);

        let entry = reg.get(0).unwrap();
        let guard = entry.data.read();
        assert_eq!(guard.len(), 1);
        assert_eq!(guard[0].num_rows(), 1);
    }

    #[test]
    fn test_registry_entries_for_rule() {
        let schema = test_schema();
        let mut reg = DerivedScanRegistry::new();
        reg.add(DerivedScanEntry {
            scan_index: 0,
            rule_name: "r1".into(),
            is_self_ref: true,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });
        reg.add(DerivedScanEntry {
            scan_index: 1,
            rule_name: "r2".into(),
            is_self_ref: false,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });
        reg.add(DerivedScanEntry {
            scan_index: 2,
            rule_name: "r1".into(),
            is_self_ref: false,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });

        assert_eq!(reg.entries_for_rule("r1").len(), 2);
        assert_eq!(reg.entries_for_rule("r2").len(), 1);
        assert_eq!(reg.entries_for_rule("r3").len(), 0);
    }

    // --- MonotonicAggState tests ---

    #[test]
    fn test_monotonic_agg_update_and_stability() {
        let bindings = vec![MonotonicFoldBinding {
            fold_name: "total".into(),
            kind: FoldAggKind::Sum,
            input_col_index: 1,
        }];
        let mut agg = MonotonicAggState::new(bindings);

        // First update
        let batch = make_batch(&["a"], &[10]);
        agg.snapshot();
        let changed = agg.update(&[0], &[batch], false).unwrap();
        assert!(changed);
        assert!(!agg.is_stable()); // changed since snapshot

        // Snapshot and check stability with no new data
        agg.snapshot();
        let changed = agg.update(&[0], &[], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }

    // --- Memory limit test ---

    #[tokio::test]
    async fn test_memory_limit_exceeded() {
        let schema = test_schema();
        // Set a tiny limit
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1, None, false);

        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
        let result = state.merge_delta(vec![batch], None).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("memory limit"), "Error was: {}", err);
    }

    // --- FixpointStream lifecycle test ---

    #[tokio::test]
    async fn test_fixpoint_stream_emitting() {
        use futures::StreamExt;

        let schema = test_schema();
        let batch1 = make_batch(&["a"], &[1]);
        let batch2 = make_batch(&["b"], &[2]);

        let metrics = ExecutionPlanMetricsSet::new();
        let baseline = BaselineMetrics::new(&metrics, 0);

        let mut stream = FixpointStream {
            state: FixpointStreamState::Emitting(vec![batch1, batch2], 0),
            schema,
            metrics: baseline,
        };

        let stream = Pin::new(&mut stream);
        let batches: Vec<RecordBatch> = stream.filter_map(|r| async { r.ok() }).collect().await;

        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].num_rows(), 1);
        assert_eq!(batches[1].num_rows(), 1);
    }

    // ── MonotonicAggState MNOR/MPROD tests ──────────────────────────────

    /// Builds a `(name: Utf8, value: Float64)` batch for probability tests.
    fn make_f64_batch(names: &[&str], values: &[f64]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, true),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(
                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
                )),
                Arc::new(Float64Array::from(values.to_vec())),
            ],
        )
        .unwrap()
    }

    fn make_nor_binding() -> Vec<MonotonicFoldBinding> {
        vec![MonotonicFoldBinding {
            fold_name: "prob".into(),
            kind: FoldAggKind::Nor,
            input_col_index: 1,
        }]
    }

    fn make_prod_binding() -> Vec<MonotonicFoldBinding> {
        vec![MonotonicFoldBinding {
            fold_name: "prob".into(),
            kind: FoldAggKind::Prod,
            input_col_index: 1,
        }]
    }

    /// Accumulator lookup key for group `name` under the "prob" fold.
    fn acc_key(name: &str) -> (Vec<ScalarKey>, String) {
        (vec![ScalarKey::Utf8(name.to_string())], "prob".to_string())
    }

    #[test]
    fn test_monotonic_nor_first_update() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.3]);
        let changed = agg.update(&[0], &[batch], false).unwrap();
        assert!(changed);
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.3).abs() < 1e-10, "expected 0.3, got {}", val);
    }

    #[test]
    fn test_monotonic_nor_two_updates() {
        // Incremental NOR: acc = 1-(1-0.3)(1-0.5) = 0.65
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.5]);
        agg.update(&[0], &[batch2], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.65).abs() < 1e-10, "expected 0.65, got {}", val);
    }

    #[test]
    fn test_monotonic_prod_first_update() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[0.6]);
        let changed = agg.update(&[0], &[batch], false).unwrap();
        assert!(changed);
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.6).abs() < 1e-10, "expected 0.6, got {}", val);
    }

    #[test]
    fn test_monotonic_prod_two_updates() {
        // Incremental PROD: acc = 0.6 * 0.8 = 0.48
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch1 = make_f64_batch(&["a"], &[0.6]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.8]);
        agg.update(&[0], &[batch2], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.48).abs() < 1e-10, "expected 0.48, got {}", val);
    }

    #[test]
    fn test_monotonic_nor_stability() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch], false).unwrap();
        agg.snapshot();
        let changed = agg.update(&[0], &[], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }

    #[test]
    fn test_monotonic_prod_stability() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[0.6]);
        agg.update(&[0], &[batch], false).unwrap();
        agg.snapshot();
        let changed = agg.update(&[0], &[], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }

    #[test]
    fn test_monotonic_nor_multi_group() {
        // (a,0.3),(b,0.5) then (a,0.5),(b,0.2) → a=0.65, b=0.6
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a", "b"], &[0.3, 0.5]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a", "b"], &[0.5, 0.2]);
        agg.update(&[0], &[batch2], false).unwrap();

        let val_a = agg.get_accumulator(&acc_key("a")).unwrap();
        let val_b = agg.get_accumulator(&acc_key("b")).unwrap();
        assert!(
            (val_a - 0.65).abs() < 1e-10,
            "expected a=0.65, got {}",
            val_a
        );
        assert!((val_b - 0.6).abs() < 1e-10, "expected b=0.6, got {}", val_b);
    }

    #[test]
    fn test_monotonic_prod_zero_absorbing() {
        // Zero absorbs: once 0.0, all further updates stay 0.0
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch1 = make_f64_batch(&["a"], &[0.5]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.0]);
        agg.update(&[0], &[batch2], false).unwrap();

        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.0).abs() < 1e-10, "expected 0.0, got {}", val);

        // Further updates don't change the absorbing zero
        agg.snapshot();
        let batch3 = make_f64_batch(&["a"], &[0.5]);
        let changed = agg.update(&[0], &[batch3], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }

    #[test]
    fn test_monotonic_nor_clamping() {
        // 1.5 clamped to 1.0: acc = 1-(1-0)(1-1) = 1.0
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[1.5]);
        agg.update(&[0], &[batch], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
    }

    #[test]
    fn test_monotonic_nor_absorbing() {
        // p=1.0 absorbs: 0.3 then 1.0 → 1.0
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[1.0]);
        agg.update(&[0], &[batch2], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
    }

    // ── MonotonicAggState strict mode tests (Phase 5) ─────────────────────

    #[test]
    fn test_monotonic_agg_strict_nor_rejects() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[1.5]);
        let result = agg.update(&[0], &[batch], true);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("strict_probability_domain"),
            "Expected strict error, got: {}",
            err
        );
    }

    #[test]
    fn test_monotonic_agg_strict_prod_rejects() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[2.0]);
        let result = agg.update(&[0], &[batch], true);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("strict_probability_domain"),
            "Expected strict error, got: {}",
            err
        );
    }

    #[test]
    fn test_monotonic_agg_strict_accepts_valid() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.5]);
        let result = agg.update(&[0], &[batch], true);
        assert!(result.is_ok());
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.5).abs() < 1e-10, "expected 0.5, got {}", val);
    }

    // ── Complement function unit tests (Phase 4) ──────────────────────────

    /// Builds a `(vid: UInt64, prob: Float64)` batch for complement tests.
    fn make_vid_prob_batch(vids: &[u64], probs: &[f64]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("vid", DataType::UInt64, true),
            Field::new("prob", DataType::Float64, true),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(UInt64Array::from(vids.to_vec())),
                Arc::new(Float64Array::from(probs.to_vec())),
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_prob_complement_basic() {
        // neg has VID=1 with prob=0.7 → complement=0.3; VID=2 absent → complement=1.0
        let body = make_vid_prob_batch(&[1, 2], &[0.9, 0.8]);
        let neg = make_vid_prob_batch(&[1], &[0.7]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result = apply_prob_complement_composite(
            vec![body],
            &[neg],
            &join_cols,
            "prob",
            "__complement_0",
        )
        .unwrap();
        assert_eq!(result.len(), 1);
        let batch = &result[0];
        let complement = batch
            .column_by_name("__complement_0")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        // VID=1: complement = 1 - 0.7 = 0.3
        assert!(
            (complement.value(0) - 0.3).abs() < 1e-10,
            "expected 0.3, got {}",
            complement.value(0)
        );
        // VID=2: absent from neg → complement = 1.0
        assert!(
            (complement.value(1) - 1.0).abs() < 1e-10,
            "expected 1.0, got {}",
            complement.value(1)
        );
    }

    #[test]
    fn test_prob_complement_noisy_or_duplicates() {
        // neg has VID=1 twice with prob=0.3 and prob=0.5
        // Combined via noisy-OR: 1-(1-0.3)(1-0.5) = 0.65
        // Complement = 1 - 0.65 = 0.35
        let body = make_vid_prob_batch(&[1], &[0.9]);
        let neg = make_vid_prob_batch(&[1, 1], &[0.3, 0.5]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result = apply_prob_complement_composite(
            vec![body],
            &[neg],
            &join_cols,
            "prob",
            "__complement_0",
        )
        .unwrap();
        let batch = &result[0];
        let complement = batch
            .column_by_name("__complement_0")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!(
            (complement.value(0) - 0.35).abs() < 1e-10,
            "expected 0.35, got {}",
            complement.value(0)
        );
    }

    #[test]
    fn test_prob_complement_empty_neg() {
        // Empty neg_facts → body passes through with complement=1.0
        let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result =
            apply_prob_complement_composite(vec![body], &[], &join_cols, "prob", "__complement_0")
                .unwrap();
        let batch = &result[0];
        let complement = batch
            .column_by_name("__complement_0")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        for i in 0..2 {
            assert!(
                (complement.value(i) - 1.0).abs() < 1e-10,
                "row {}: expected 1.0, got {}",
                i,
                complement.value(i)
            );
        }
    }

    #[test]
    fn test_anti_join_basic() {
        // body [1,2,3], neg [2] → result [1,3]
        let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
        let neg = make_vid_prob_batch(&[2], &[0.0]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
        assert_eq!(result.len(), 1);
        let batch = &result[0];
        assert_eq!(batch.num_rows(), 2);
        let vids = batch
            .column_by_name("vid")
            .unwrap()
            .as_any()
            .downcast_ref::<UInt64Array>()
            .unwrap();
        assert_eq!(vids.value(0), 1);
        assert_eq!(vids.value(1), 3);
    }

    #[test]
    fn test_anti_join_empty_neg() {
        // Empty neg → all rows kept
        let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result = apply_anti_join_composite(vec![body], &[], &join_cols).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 3);
    }

    #[test]
    fn test_anti_join_all_excluded() {
        // neg covers all body rows → empty result
        let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
        let neg = make_vid_prob_batch(&[1, 2], &[0.0, 0.0]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
        let total: usize = result.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total, 0);
    }

    #[test]
    fn test_multiply_prob_single_complement() {
        // prob=0.8, complement=0.5 → output prob=0.4; complement col removed
        let body = make_vid_prob_batch(&[1], &[0.8]);
        // Add a complement column
        let complement_arr = Float64Array::from(vec![0.5]);
        let mut cols: Vec<ArrayRef> = body.columns().to_vec();
        cols.push(Arc::new(complement_arr));
        let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
        fields.push(Arc::new(Field::new(
            "__complement_0",
            DataType::Float64,
            true,
        )));
        let schema = Arc::new(Schema::new(fields));
        let batch = RecordBatch::try_new(schema, cols).unwrap();

        let result =
            multiply_prob_factors(vec![batch], Some("prob"), &["__complement_0".to_string()])
                .unwrap();
        assert_eq!(result.len(), 1);
        let out = &result[0];
        // Complement column should be removed
        assert!(out.column_by_name("__complement_0").is_none());
        let prob = out
            .column_by_name("prob")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!(
            (prob.value(0) - 0.4).abs() < 1e-10,
            "expected 0.4, got {}",
            prob.value(0)
        );
    }

    #[test]
    fn test_multiply_prob_multiple_complements() {
        // prob=0.8, c1=0.5, c2=0.6 → 0.8×0.5×0.6=0.24
        let body = make_vid_prob_batch(&[1], &[0.8]);
        let c1 = Float64Array::from(vec![0.5]);
        let c2 = Float64Array::from(vec![0.6]);
        let mut cols: Vec<ArrayRef> = body.columns().to_vec();
        cols.push(Arc::new(c1));
        cols.push(Arc::new(c2));
        let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
        fields.push(Arc::new(Field::new("__c1", DataType::Float64, true)));
        fields.push(Arc::new(Field::new("__c2", DataType::Float64, true)));
        let schema = Arc::new(Schema::new(fields));
        let batch = RecordBatch::try_new(schema, cols).unwrap();

        let result = multiply_prob_factors(
            vec![batch],
            Some("prob"),
            &["__c1".to_string(), "__c2".to_string()],
        )
        .unwrap();
        let out = &result[0];
        assert!(out.column_by_name("__c1").is_none());
        assert!(out.column_by_name("__c2").is_none());
        let prob = out
            .column_by_name("prob")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!(
            (prob.value(0) - 0.24).abs() < 1e-10,
            "expected 0.24, got {}",
            prob.value(0)
        );
    }

    #[test]
    fn test_multiply_prob_no_prob_column() {
        // No prob column → combined complements become the output
        let schema = Arc::new(Schema::new(vec![
            Field::new("vid", DataType::UInt64, true),
            Field::new("__c1", DataType::Float64, true),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(UInt64Array::from(vec![1u64])),
                Arc::new(Float64Array::from(vec![0.7])),
            ],
        )
        .unwrap();

        let result = multiply_prob_factors(vec![batch], None, &["__c1".to_string()]).unwrap();
        let out = &result[0];
        // __c1 should be removed since it's a complement column
        assert!(out.column_by_name("__c1").is_none());
        // Only vid column remains
        assert_eq!(out.num_columns(), 1);
    }
}