// uni_query/query/df_graph/locy_fixpoint.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Fixpoint iteration operator for recursive Locy strata.
5//!
6//! `FixpointExec` drives semi-naive evaluation: it repeatedly evaluates the rules
7//! in a recursive stratum, feeding back deltas until no new facts are produced.
8
9use crate::query::df_graph::GraphExecutionContext;
10use crate::query::df_graph::common::{
11    ScalarKey, arrow_err, collect_all_partitions, compute_plan_properties, execute_subplan,
12    extract_scalar_key,
13};
14use crate::query::df_graph::locy_best_by::{BestByExec, SortCriterion};
15use crate::query::df_graph::locy_errors::LocyRuntimeError;
16use crate::query::df_graph::locy_explain::{
17    ProofTerm, ProvenanceAnnotation, ProvenanceStore, compute_proof_probability,
18};
19use crate::query::df_graph::locy_fold::{FoldBinding, FoldExec};
20use crate::query::df_graph::locy_priority::PriorityExec;
21use crate::query::planner::LogicalPlan;
22use arrow_array::RecordBatch;
23use arrow_row::{RowConverter, SortField};
24use arrow_schema::SchemaRef;
25use datafusion::common::JoinType;
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
29use datafusion::physical_plan::memory::MemoryStream;
30use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
31use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
32use futures::Stream;
33use parking_lot::RwLock;
34use std::any::Any;
35use std::collections::{HashMap, HashSet};
36use std::fmt;
37use std::pin::Pin;
38use std::sync::{Arc, RwLock as StdRwLock};
39use std::task::{Context, Poll};
40use std::time::{Duration, Instant};
41use uni_common::Value;
42use uni_common::core::schema::Schema as UniSchema;
43use uni_cypher::ast::Expr;
44use uni_locy::RuntimeWarning;
45use uni_store::storage::manager::StorageManager;
46
47// ---------------------------------------------------------------------------
48// DerivedScanRegistry — injection point for IS-ref data into subplans
49// ---------------------------------------------------------------------------
50
/// A single entry in the derived scan registry.
///
/// Each entry corresponds to one `LocyDerivedScan` node in the logical plan tree.
/// The `data` handle is shared with the logical plan node so that writing data here
/// makes it visible when the subplan is re-planned and executed.
#[derive(Debug)]
pub struct DerivedScanEntry {
    /// Index matching the `scan_index` in `LocyDerivedScan`.
    pub scan_index: usize,
    /// Name of the rule this scan reads from.
    pub rule_name: String,
    /// Whether this is a self-referential scan (rule references itself).
    pub is_self_ref: bool,
    /// Shared data handle — write batches here to inject into subplans.
    /// Replaced wholesale by `DerivedScanRegistry::write_data`.
    pub data: Arc<RwLock<Vec<RecordBatch>>>,
    /// Schema of the derived relation.
    pub schema: SchemaRef,
}
69
/// Registry of derived scan handles for fixpoint iteration.
///
/// During fixpoint, each clause body may reference derived relations via
/// `LocyDerivedScan` nodes. The registry maps scan indices to shared data
/// handles so the fixpoint loop can inject delta/full facts before each
/// iteration.
#[derive(Debug, Default)]
pub struct DerivedScanRegistry {
    /// Registered entries; looked up linearly by `scan_index`.
    entries: Vec<DerivedScanEntry>,
}
80
81impl DerivedScanRegistry {
82    /// Create a new empty registry.
83    pub fn new() -> Self {
84        Self::default()
85    }
86
87    /// Add an entry to the registry.
88    pub fn add(&mut self, entry: DerivedScanEntry) {
89        self.entries.push(entry);
90    }
91
92    /// Get an entry by scan index.
93    pub fn get(&self, scan_index: usize) -> Option<&DerivedScanEntry> {
94        self.entries.iter().find(|e| e.scan_index == scan_index)
95    }
96
97    /// Write data into a scan entry's shared handle.
98    pub fn write_data(&self, scan_index: usize, batches: Vec<RecordBatch>) {
99        if let Some(entry) = self.get(scan_index) {
100            let mut guard = entry.data.write();
101            *guard = batches;
102        }
103    }
104
105    /// Get all entries for a given rule name.
106    pub fn entries_for_rule(&self, rule_name: &str) -> Vec<&DerivedScanEntry> {
107        self.entries
108            .iter()
109            .filter(|e| e.rule_name == rule_name)
110            .collect()
111    }
112}
113
114// ---------------------------------------------------------------------------
115// MonotonicAggState — tracking monotonic aggregates across iterations
116// ---------------------------------------------------------------------------
117
/// Monotonic aggregate binding: maps a fold name to its aggregate kind and column.
#[derive(Debug, Clone)]
pub struct MonotonicFoldBinding {
    /// Name of the fold; used as part of the accumulator map key.
    pub fold_name: String,
    /// Which monotonic aggregate to apply (Sum/Count/Max/Min/Nor/Prod).
    pub kind: crate::query::df_graph::locy_fold::FoldAggKind,
    /// Positional fallback for resolving the input column when name
    /// resolution fails.
    pub input_col_index: usize,
    /// Column name for name-based resolution (more robust than positional index).
    pub input_col_name: Option<String>,
}
127
/// Tracks monotonic aggregate accumulators across fixpoint iterations.
///
/// After each iteration, accumulators are updated and compared to their previous
/// snapshot. The fixpoint has converged (w.r.t. aggregates) when all accumulators
/// are stable (no change between iterations).
#[derive(Debug)]
pub struct MonotonicAggState {
    /// Current accumulator values keyed by (group_key, fold_name).
    accumulators: HashMap<(Vec<ScalarKey>, String), f64>,
    /// Snapshot from the previous iteration for stability check.
    prev_snapshot: HashMap<(Vec<ScalarKey>, String), f64>,
    /// Bindings describing which aggregates to track.
    bindings: Vec<MonotonicFoldBinding>,
}
142
143impl MonotonicAggState {
144    /// Create a new monotonic aggregate state.
145    pub fn new(bindings: Vec<MonotonicFoldBinding>) -> Self {
146        Self {
147            accumulators: HashMap::new(),
148            prev_snapshot: HashMap::new(),
149            bindings,
150        }
151    }
152
153    /// Update accumulators with new delta batches.
154    ///
155    /// Returns `true` if any accumulator value changed. When `strict` is
156    /// `true`, Nor/Prod inputs outside `[0, 1]` produce an error instead
157    /// of being clamped.
158    pub fn update(
159        &mut self,
160        key_indices: &[usize],
161        delta_batches: &[RecordBatch],
162        strict: bool,
163    ) -> DFResult<bool> {
164        use crate::query::df_graph::locy_fold::FoldAggKind;
165
166        let mut changed = false;
167        for batch in delta_batches {
168            for row_idx in 0..batch.num_rows() {
169                let group_key = extract_scalar_key(batch, key_indices, row_idx);
170                for binding in &self.bindings {
171                    // Resolve input column by name first, fall back to index.
172                    let idx = binding
173                        .input_col_name
174                        .as_ref()
175                        .and_then(|name| batch.schema().index_of(name).ok())
176                        .unwrap_or(binding.input_col_index);
177                    if idx >= batch.num_columns() {
178                        continue;
179                    }
180                    let col = batch.column(idx);
181                    let val = extract_f64(col.as_ref(), row_idx);
182                    if let Some(val) = val {
183                        let map_key = (group_key.clone(), binding.fold_name.clone());
184                        let entry = self
185                            .accumulators
186                            .entry(map_key)
187                            .or_insert(binding.kind.identity().unwrap_or(0.0));
188                        let old = *entry;
189                        match binding.kind {
190                            FoldAggKind::Sum | FoldAggKind::Count => *entry += val,
191                            FoldAggKind::Max => {
192                                if val > *entry {
193                                    *entry = val;
194                                }
195                            }
196                            FoldAggKind::Min => {
197                                if val < *entry {
198                                    *entry = val;
199                                }
200                            }
201                            FoldAggKind::Nor => {
202                                if strict && !(0.0..=1.0).contains(&val) {
203                                    return Err(datafusion::error::DataFusionError::Execution(
204                                        format!(
205                                            "strict_probability_domain: MNOR input {val} is outside [0, 1]"
206                                        ),
207                                    ));
208                                }
209                                if !strict && !(0.0..=1.0).contains(&val) {
210                                    tracing::warn!(
211                                        "MNOR input {val} outside [0,1], clamped to {}",
212                                        val.clamp(0.0, 1.0)
213                                    );
214                                }
215                                let p = val.clamp(0.0, 1.0);
216                                *entry = 1.0 - (1.0 - *entry) * (1.0 - p);
217                            }
218                            FoldAggKind::Prod => {
219                                if strict && !(0.0..=1.0).contains(&val) {
220                                    return Err(datafusion::error::DataFusionError::Execution(
221                                        format!(
222                                            "strict_probability_domain: MPROD input {val} is outside [0, 1]"
223                                        ),
224                                    ));
225                                }
226                                if !strict && !(0.0..=1.0).contains(&val) {
227                                    tracing::warn!(
228                                        "MPROD input {val} outside [0,1], clamped to {}",
229                                        val.clamp(0.0, 1.0)
230                                    );
231                                }
232                                let p = val.clamp(0.0, 1.0);
233                                *entry *= p;
234                            }
235                            _ => {}
236                        }
237                        if (*entry - old).abs() > f64::EPSILON {
238                            changed = true;
239                        }
240                    }
241                }
242            }
243        }
244        Ok(changed)
245    }
246
247    /// Take a snapshot of current accumulators for stability comparison.
248    pub fn snapshot(&mut self) {
249        self.prev_snapshot = self.accumulators.clone();
250    }
251
252    /// Check if accumulators are stable (no change since last snapshot).
253    pub fn is_stable(&self) -> bool {
254        if self.accumulators.len() != self.prev_snapshot.len() {
255            return false;
256        }
257        for (key, val) in &self.accumulators {
258            match self.prev_snapshot.get(key) {
259                Some(prev) if (*val - *prev).abs() <= f64::EPSILON => {}
260                _ => return false,
261            }
262        }
263        true
264    }
265
266    /// Test-only accessor for accumulator values.
267    #[cfg(test)]
268    pub(crate) fn get_accumulator(&self, key: &(Vec<ScalarKey>, String)) -> Option<f64> {
269        self.accumulators.get(key).copied()
270    }
271}
272
273/// Extract f64 value from an Arrow column at a given row index.
274fn extract_f64(col: &dyn arrow_array::Array, row_idx: usize) -> Option<f64> {
275    if col.is_null(row_idx) {
276        return None;
277    }
278    if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Float64Array>() {
279        Some(arr.value(row_idx))
280    } else {
281        col.as_any()
282            .downcast_ref::<arrow_array::Int64Array>()
283            .map(|arr| arr.value(row_idx) as f64)
284    }
285}
286
287// ---------------------------------------------------------------------------
288// RowDedupState — Arrow RowConverter-based persistent dedup set
289// ---------------------------------------------------------------------------
290
/// Arrow-native row deduplication using [`RowConverter`].
///
/// Unlike the legacy `HashSet<Vec<ScalarKey>>` approach, this struct maintains a
/// persistent `seen` set across iterations so per-iteration cost is O(M) where M
/// is the number of candidate rows — the full facts table is never re-scanned.
struct RowDedupState {
    /// Encodes batch columns into byte-comparable row representations.
    converter: RowConverter,
    /// Row encodings of every row accepted so far.
    seen: HashSet<Box<[u8]>>,
}
300
301impl RowDedupState {
302    /// Try to build a `RowDedupState` for the given schema.
303    ///
304    /// Returns `None` if any column type is not supported by `RowConverter`
305    /// (triggers legacy fallback).
306    fn try_new(schema: &SchemaRef) -> Option<Self> {
307        let fields: Vec<SortField> = schema
308            .fields()
309            .iter()
310            .map(|f| SortField::new(f.data_type().clone()))
311            .collect();
312        match RowConverter::new(fields) {
313            Ok(converter) => Some(Self {
314                converter,
315                seen: HashSet::new(),
316            }),
317            Err(e) => {
318                tracing::warn!(
319                    "RowDedupState: RowConverter unsupported for schema, falling back to legacy dedup: {}",
320                    e
321                );
322                None
323            }
324        }
325    }
326
327    /// Populate the seen set from existing fact batches.
328    ///
329    /// Used after BEST BY in-loop pruning replaces the fact set, so that delta
330    /// computation in subsequent iterations correctly recognizes surviving facts.
331    fn ingest_existing(&mut self, facts: &[RecordBatch], _schema: &SchemaRef) {
332        self.seen.clear();
333        for batch in facts {
334            if batch.num_rows() == 0 {
335                continue;
336            }
337            let arrays: Vec<_> = batch.columns().to_vec();
338            if let Ok(rows) = self.converter.convert_columns(&arrays) {
339                for row_idx in 0..batch.num_rows() {
340                    let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
341                    self.seen.insert(row_bytes);
342                }
343            }
344        }
345    }
346
347    /// Filter `candidates` to only rows not yet seen, updating the persistent set.
348    ///
349    /// Both cross-iteration dedup (rows already accepted in prior iterations) and
350    /// within-batch dedup (duplicate rows in a single candidate batch) are handled
351    /// in a single pass.
352    fn compute_delta(
353        &mut self,
354        candidates: &[RecordBatch],
355        schema: &SchemaRef,
356    ) -> DFResult<Vec<RecordBatch>> {
357        let mut delta_batches = Vec::new();
358        for batch in candidates {
359            if batch.num_rows() == 0 {
360                continue;
361            }
362
363            // Vectorized encoding of all rows in this batch.
364            let arrays: Vec<_> = batch.columns().to_vec();
365            let rows = self.converter.convert_columns(&arrays).map_err(arrow_err)?;
366
367            // One pass: check+insert into persistent seen set.
368            let mut keep = Vec::with_capacity(batch.num_rows());
369            for row_idx in 0..batch.num_rows() {
370                let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
371                keep.push(self.seen.insert(row_bytes));
372            }
373
374            let keep_mask = arrow_array::BooleanArray::from(keep);
375            let new_cols = batch
376                .columns()
377                .iter()
378                .map(|col| {
379                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
380                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
381                    })
382                })
383                .collect::<DFResult<Vec<_>>>()?;
384
385            if new_cols.first().is_some_and(|c| !c.is_empty()) {
386                let filtered = RecordBatch::try_new(Arc::clone(schema), new_cols).map_err(|e| {
387                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
388                })?;
389                delta_batches.push(filtered);
390            }
391        }
392        Ok(delta_batches)
393    }
394}
395
396// ---------------------------------------------------------------------------
397// FixpointState — per-rule delta tracking during fixpoint iteration
398// ---------------------------------------------------------------------------
399
/// Per-rule state for fixpoint iteration.
///
/// Tracks accumulated facts and the delta (new facts from the latest iteration).
/// Deduplication uses Arrow [`RowConverter`] with a persistent seen set (O(M) per
/// iteration) when supported, with a legacy `HashSet<Vec<ScalarKey>>` fallback.
pub struct FixpointState {
    /// Rule name; used in error reporting (e.g. memory-limit errors).
    rule_name: String,
    /// All accepted facts, accumulated across iterations.
    facts: Vec<RecordBatch>,
    /// New facts produced by the most recent iteration only.
    delta: Vec<RecordBatch>,
    /// Output schema; may be reconciled against the physical plan's actual output.
    schema: SchemaRef,
    /// Indices of KEY columns within `schema`.
    key_column_indices: Vec<usize>,
    /// KEY column names for recomputing indices after schema reconciliation.
    key_column_names: Vec<String>,
    /// All column indices for full-row dedup (legacy path only).
    all_column_indices: Vec<usize>,
    /// Running total of facts bytes for memory limit tracking.
    facts_bytes: usize,
    /// Maximum bytes allowed for this derived relation.
    max_derived_bytes: usize,
    /// Optional monotonic aggregate tracking.
    monotonic_agg: Option<MonotonicAggState>,
    /// Arrow RowConverter-based dedup state; `None` triggers legacy fallback.
    row_dedup: Option<RowDedupState>,
    /// Whether strict probability domain checks are enabled.
    strict_probability_domain: bool,
}
426
427impl FixpointState {
428    /// Create a new fixpoint state for a rule.
429    pub fn new(
430        rule_name: String,
431        schema: SchemaRef,
432        key_column_indices: Vec<usize>,
433        max_derived_bytes: usize,
434        monotonic_agg: Option<MonotonicAggState>,
435        strict_probability_domain: bool,
436    ) -> Self {
437        let num_cols = schema.fields().len();
438        let row_dedup = RowDedupState::try_new(&schema);
439        let key_column_names: Vec<String> = key_column_indices
440            .iter()
441            .filter_map(|&i| schema.fields().get(i).map(|f| f.name().clone()))
442            .collect();
443        Self {
444            rule_name,
445            facts: Vec::new(),
446            delta: Vec::new(),
447            schema,
448            key_column_indices,
449            key_column_names,
450            all_column_indices: (0..num_cols).collect(),
451            facts_bytes: 0,
452            max_derived_bytes,
453            monotonic_agg,
454            row_dedup,
455            strict_probability_domain,
456        }
457    }
458
459    /// Reconcile the pre-computed schema with the actual physical plan output.
460    ///
461    /// `infer_expr_type` may guess wrong (e.g. `Property → Float64` for a
462    /// string column).  When the first real batch arrives with a different
463    /// schema, update ours so that `RowDedupState` / `RecordBatch::try_new`
464    /// use the correct types.
465    fn reconcile_schema(&mut self, actual_schema: &SchemaRef) {
466        if self.schema.fields() != actual_schema.fields() {
467            tracing::debug!(
468                rule = %self.rule_name,
469                "Reconciling fixpoint schema from physical plan output",
470            );
471            self.schema = Arc::clone(actual_schema);
472            self.row_dedup = RowDedupState::try_new(&self.schema);
473            // Recompute key_column_indices from stored KEY column names.
474            // Without this, FoldExec groups by wrong columns when the
475            // physical plan reorders columns vs the pre-inferred schema.
476            let new_indices: Vec<usize> = self
477                .key_column_names
478                .iter()
479                .filter_map(|name| actual_schema.index_of(name).ok())
480                .collect();
481            if new_indices.len() == self.key_column_names.len() {
482                self.key_column_indices = new_indices;
483            }
484            // else: not all KEY columns found in new schema — keep original indices
485            let num_cols = actual_schema.fields().len();
486            self.all_column_indices = (0..num_cols).collect();
487        }
488    }
489
    /// Merge candidate rows into facts, computing delta (truly new rows).
    ///
    /// Returns `true` if any new facts were added. On success the side effects
    /// are: `self.delta` replaced with the new rows, `self.facts` extended by
    /// them, `self.facts_bytes` increased, and monotonic aggregates
    /// snapshotted then updated.
    pub async fn merge_delta(
        &mut self,
        candidates: Vec<RecordBatch>,
        task_ctx: Option<Arc<TaskContext>>,
    ) -> DFResult<bool> {
        // No candidates (or all empty): clear delta, report no change.
        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
            self.delta.clear();
            return Ok(false);
        }

        // Reconcile schema from the first non-empty candidate batch.
        // The physical plan's output types are authoritative over the
        // planner's inferred types.
        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
            self.reconcile_schema(&first.schema());
        }

        // Round floats for stable dedup
        let candidates = round_float_columns(&candidates);

        // Compute delta: rows in candidates not already in facts
        let delta = self.compute_delta(&candidates, task_ctx.as_ref()).await?;

        if delta.is_empty() || delta.iter().all(|b| b.num_rows() == 0) {
            self.delta.clear();
            // Update monotonic aggs even with empty delta (for stability check).
            // snapshot() with no subsequent update() makes is_stable() true,
            // which is what an empty delta should mean.
            if let Some(ref mut agg) = self.monotonic_agg {
                agg.snapshot();
            }
            return Ok(false);
        }

        // Check memory limit BEFORE appending so facts never exceeds budget.
        let delta_bytes: usize = delta.iter().map(batch_byte_size).sum();
        if self.facts_bytes + delta_bytes > self.max_derived_bytes {
            return Err(datafusion::error::DataFusionError::Execution(
                LocyRuntimeError::MemoryLimitExceeded {
                    rule: self.rule_name.clone(),
                    bytes: self.facts_bytes + delta_bytes,
                    limit: self.max_derived_bytes,
                }
                .to_string(),
            ));
        }

        // Update monotonic aggs: snapshot first so is_stable() compares the
        // post-update values against the pre-update state.
        if let Some(ref mut agg) = self.monotonic_agg {
            agg.snapshot();
            agg.update(
                &self.key_column_indices,
                &delta,
                self.strict_probability_domain,
            )?;
        }

        // Append delta to facts (RecordBatch clones are cheap Arc bumps).
        self.facts_bytes += delta_bytes;
        self.facts.extend(delta.iter().cloned());
        self.delta = delta;

        Ok(true)
    }
555
556    /// Dispatch to vectorized LeftAntiJoin, Arrow RowConverter dedup, or legacy ScalarKey dedup.
557    ///
558    /// Priority order:
559    /// 1. `arrow_left_anti_dedup` when `total_existing >= DEDUP_ANTI_JOIN_THRESHOLD` and task_ctx available.
560    /// 2. `RowDedupState` (persistent HashSet, O(M) per iteration) when schema is supported.
561    /// 3. `compute_delta_legacy` (rebuilds from facts, fallback for unsupported column types).
562    async fn compute_delta(
563        &mut self,
564        candidates: &[RecordBatch],
565        task_ctx: Option<&Arc<TaskContext>>,
566    ) -> DFResult<Vec<RecordBatch>> {
567        let total_existing: usize = self.facts.iter().map(|b| b.num_rows()).sum();
568        if total_existing >= DEDUP_ANTI_JOIN_THRESHOLD
569            && let Some(ctx) = task_ctx
570        {
571            return arrow_left_anti_dedup(candidates.to_vec(), &self.facts, &self.schema, ctx)
572                .await;
573        }
574        if let Some(ref mut rd) = self.row_dedup {
575            rd.compute_delta(candidates, &self.schema)
576        } else {
577            self.compute_delta_legacy(candidates)
578        }
579    }
580
581    /// Legacy dedup: rebuild a `HashSet<Vec<ScalarKey>>` from all facts each call.
582    ///
583    /// Used as fallback when `RowConverter` does not support the schema's column types.
584    fn compute_delta_legacy(&self, candidates: &[RecordBatch]) -> DFResult<Vec<RecordBatch>> {
585        // Build set of existing fact row keys (ALL columns)
586        let mut existing: HashSet<Vec<ScalarKey>> = HashSet::new();
587        for batch in &self.facts {
588            for row_idx in 0..batch.num_rows() {
589                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
590                existing.insert(key);
591            }
592        }
593
594        let mut delta_batches = Vec::new();
595        for batch in candidates {
596            if batch.num_rows() == 0 {
597                continue;
598            }
599            // Filter to only new rows
600            let mut keep = Vec::with_capacity(batch.num_rows());
601            for row_idx in 0..batch.num_rows() {
602                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
603                keep.push(!existing.contains(&key));
604            }
605
606            // Also dedup within the candidate batch itself
607            for (row_idx, kept) in keep.iter_mut().enumerate() {
608                if *kept {
609                    let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
610                    if !existing.insert(key) {
611                        *kept = false;
612                    }
613                }
614            }
615
616            let keep_mask = arrow_array::BooleanArray::from(keep);
617            let new_rows = batch
618                .columns()
619                .iter()
620                .map(|col| {
621                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
622                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
623                    })
624                })
625                .collect::<DFResult<Vec<_>>>()?;
626
627            if new_rows.first().is_some_and(|c| !c.is_empty()) {
628                let filtered =
629                    RecordBatch::try_new(Arc::clone(&self.schema), new_rows).map_err(|e| {
630                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
631                    })?;
632                delta_batches.push(filtered);
633            }
634        }
635
636        Ok(delta_batches)
637    }
638
639    /// Check if this rule has converged (no new facts and aggs stable).
640    pub fn is_converged(&self) -> bool {
641        let delta_empty = self.delta.is_empty() || self.delta.iter().all(|b| b.num_rows() == 0);
642        let agg_stable = self.monotonic_agg.as_ref().is_none_or(|a| a.is_stable());
643        delta_empty && agg_stable
644    }
645
    /// Get all accumulated facts (the deduplicated union of every iteration's delta).
    pub fn all_facts(&self) -> &[RecordBatch] {
        &self.facts
    }
650
    /// Get the delta from the latest iteration.
    ///
    /// Empty (or all-empty batches) when the last iteration produced no new facts.
    pub fn all_delta(&self) -> &[RecordBatch] {
        &self.delta
    }
655
    /// Consume self and return the accumulated fact batches.
    pub fn into_facts(self) -> Vec<RecordBatch> {
        self.facts
    }
660
661    /// Merge candidates using BEST BY semantics.
662    ///
663    /// Combines existing facts with new candidates, keeping only the best row
664    /// per KEY group according to `sort_criteria`. Returns `true` if the
665    /// best-per-KEY fact set actually changed (a genuinely better value was
666    /// found or a new KEY appeared).
667    ///
668    /// This replaces `merge_delta` for rules with BEST BY, enabling convergence
669    /// on cyclic graphs where dominated ALONG values would otherwise produce an
670    /// unbounded stream of "new" full-row facts.
671    pub fn merge_best_by(
672        &mut self,
673        candidates: Vec<RecordBatch>,
674        sort_criteria: &[SortCriterion],
675    ) -> DFResult<bool> {
676        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
677            self.delta.clear();
678            return Ok(false);
679        }
680
681        // Reconcile schema from the first non-empty candidate batch.
682        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
683            self.reconcile_schema(&first.schema());
684        }
685
686        // Round floats for stable dedup.
687        let candidates = round_float_columns(&candidates);
688
689        // Snapshot existing best-per-KEY facts for change detection.
690        let old_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> =
691            self.build_key_criteria_map(sort_criteria);
692
693        // Concat existing facts + new candidates.
694        let mut all_batches = self.facts.clone();
695        all_batches.extend(candidates);
696        let all_batches: Vec<_> = all_batches
697            .into_iter()
698            .filter(|b| b.num_rows() > 0)
699            .collect();
700        if all_batches.is_empty() {
701            self.delta.clear();
702            return Ok(false);
703        }
704
705        let combined = arrow::compute::concat_batches(&self.schema, &all_batches)
706            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
707
708        if combined.num_rows() == 0 {
709            self.delta.clear();
710            return Ok(false);
711        }
712
713        // Sort by KEY ASC then criteria, so the best row per KEY group comes
714        // first.
715        let mut sort_columns = Vec::new();
716        for &ki in &self.key_column_indices {
717            if ki >= combined.num_columns() {
718                continue;
719            }
720            sort_columns.push(arrow::compute::SortColumn {
721                values: Arc::clone(combined.column(ki)),
722                options: Some(arrow::compute::SortOptions {
723                    descending: false,
724                    nulls_first: false,
725                }),
726            });
727        }
728        for criterion in sort_criteria {
729            if criterion.col_index >= combined.num_columns() {
730                continue;
731            }
732            sort_columns.push(arrow::compute::SortColumn {
733                values: Arc::clone(combined.column(criterion.col_index)),
734                options: Some(arrow::compute::SortOptions {
735                    descending: !criterion.ascending,
736                    nulls_first: criterion.nulls_first,
737                }),
738            });
739        }
740
741        let sorted_indices =
742            arrow::compute::lexsort_to_indices(&sort_columns, None).map_err(arrow_err)?;
743        let sorted_columns: Vec<_> = combined
744            .columns()
745            .iter()
746            .map(|col| arrow::compute::take(col.as_ref(), &sorted_indices, None))
747            .collect::<Result<Vec<_>, _>>()
748            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
749        let sorted = RecordBatch::try_new(Arc::clone(&self.schema), sorted_columns)
750            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
751
752        // Dedup: keep first (best) row per KEY group.
753        let mut keep_indices: Vec<u32> = Vec::new();
754        let mut prev_key: Option<Vec<ScalarKey>> = None;
755        for row_idx in 0..sorted.num_rows() {
756            let key = extract_scalar_key(&sorted, &self.key_column_indices, row_idx);
757            let is_new_group = match &prev_key {
758                None => true,
759                Some(prev) => *prev != key,
760            };
761            if is_new_group {
762                keep_indices.push(row_idx as u32);
763                prev_key = Some(key);
764            }
765        }
766
767        let keep_array = arrow_array::UInt32Array::from(keep_indices);
768        let output_columns: Vec<_> = sorted
769            .columns()
770            .iter()
771            .map(|col| arrow::compute::take(col.as_ref(), &keep_array, None))
772            .collect::<Result<Vec<_>, _>>()
773            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
774        let pruned = RecordBatch::try_new(Arc::clone(&self.schema), output_columns)
775            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
776
777        // Detect whether the best-per-KEY set actually changed.
778        let new_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> = {
779            let mut map = HashMap::new();
780            for row_idx in 0..pruned.num_rows() {
781                let key = extract_scalar_key(&pruned, &self.key_column_indices, row_idx);
782                let criteria: Vec<ScalarKey> = sort_criteria
783                    .iter()
784                    .flat_map(|c| extract_scalar_key(&pruned, &[c.col_index], row_idx))
785                    .collect();
786                map.insert(key, criteria);
787            }
788            map
789        };
790        let changed = old_best != new_best;
791
792        tracing::debug!(
793            rule = %self.rule_name,
794            old_keys = old_best.len(),
795            new_keys = new_best.len(),
796            changed = changed,
797            "BEST BY merge"
798        );
799
800        // Replace facts with the pruned set.
801        self.facts_bytes = batch_byte_size(&pruned);
802        self.facts = vec![pruned];
803        if changed {
804            // Delta is conceptually the new/improved facts, but since we
805            // replaced the entire set, just mark delta non-empty.
806            self.delta = self.facts.clone();
807        } else {
808            self.delta.clear();
809        }
810
811        // Rebuild row dedup from pruned facts for consistency.
812        self.row_dedup = RowDedupState::try_new(&self.schema);
813        if let Some(ref mut rd) = self.row_dedup {
814            rd.ingest_existing(&self.facts, &self.schema);
815        }
816
817        Ok(changed)
818    }
819
820    /// Build a map from KEY column values to sort criteria values.
821    fn build_key_criteria_map(
822        &self,
823        sort_criteria: &[SortCriterion],
824    ) -> HashMap<Vec<ScalarKey>, Vec<ScalarKey>> {
825        let mut map = HashMap::new();
826        for batch in &self.facts {
827            for row_idx in 0..batch.num_rows() {
828                let key = extract_scalar_key(batch, &self.key_column_indices, row_idx);
829                let criteria: Vec<ScalarKey> = sort_criteria
830                    .iter()
831                    .flat_map(|c| extract_scalar_key(batch, &[c.col_index], row_idx))
832                    .collect();
833                map.insert(key, criteria);
834            }
835        }
836        map
837    }
838}
839
840/// Estimate byte size of a RecordBatch.
841fn batch_byte_size(batch: &RecordBatch) -> usize {
842    batch
843        .columns()
844        .iter()
845        .map(|col| col.get_buffer_memory_size())
846        .sum()
847}
848
849// ---------------------------------------------------------------------------
850// Float rounding for stable dedup
851// ---------------------------------------------------------------------------
852
853/// Round all Float64 columns to 12 decimal places for stable dedup.
854fn round_float_columns(batches: &[RecordBatch]) -> Vec<RecordBatch> {
855    batches
856        .iter()
857        .map(|batch| {
858            let schema = batch.schema();
859            let has_float = schema
860                .fields()
861                .iter()
862                .any(|f| *f.data_type() == arrow_schema::DataType::Float64);
863            if !has_float {
864                return batch.clone();
865            }
866
867            let columns: Vec<arrow_array::ArrayRef> = batch
868                .columns()
869                .iter()
870                .enumerate()
871                .map(|(i, col)| {
872                    if *schema.field(i).data_type() == arrow_schema::DataType::Float64 {
873                        let arr = col
874                            .as_any()
875                            .downcast_ref::<arrow_array::Float64Array>()
876                            .unwrap();
877                        let rounded: arrow_array::Float64Array = arr
878                            .iter()
879                            .map(|v| v.map(|f| (f * 1e12).round() / 1e12))
880                            .collect();
881                        Arc::new(rounded) as arrow_array::ArrayRef
882                    } else {
883                        Arc::clone(col)
884                    }
885                })
886                .collect();
887
888            RecordBatch::try_new(schema, columns).unwrap_or_else(|_| batch.clone())
889        })
890        .collect()
891}
892
893// ---------------------------------------------------------------------------
894// LeftAntiJoin delta deduplication
895// ---------------------------------------------------------------------------
896
/// Row threshold above which the vectorized Arrow LeftAntiJoin dedup path is used.
///
/// Below this threshold the persistent `RowDedupState` HashSet is O(M) and
/// avoids rebuilding the existing-row set; above it DataFusion's vectorized
/// HashJoinExec is more cache-efficient.
///
/// NOTE(review): the crossover value is not derived in this file — presumably
/// empirical; confirm against a benchmark before tuning.
const DEDUP_ANTI_JOIN_THRESHOLD: usize = 300;
903
904/// Deduplicate `candidates` against `existing` using DataFusion's HashJoinExec.
905///
906/// Returns rows in `candidates` that do not appear in `existing` (LeftAnti semantics).
907/// `null_equals_null = true` so NULLs are treated as equal for dedup purposes.
908async fn arrow_left_anti_dedup(
909    candidates: Vec<RecordBatch>,
910    existing: &[RecordBatch],
911    schema: &SchemaRef,
912    task_ctx: &Arc<TaskContext>,
913) -> DFResult<Vec<RecordBatch>> {
914    if existing.is_empty() || existing.iter().all(|b| b.num_rows() == 0) {
915        return Ok(candidates);
916    }
917
918    let left: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(candidates, Arc::clone(schema)));
919    let right: Arc<dyn ExecutionPlan> =
920        Arc::new(InMemoryExec::new(existing.to_vec(), Arc::clone(schema)));
921
922    let on: Vec<(
923        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
924        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
925    )> = schema
926        .fields()
927        .iter()
928        .enumerate()
929        .map(|(i, field)| {
930            let l: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
931                datafusion::physical_plan::expressions::Column::new(field.name(), i),
932            );
933            let r: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
934                datafusion::physical_plan::expressions::Column::new(field.name(), i),
935            );
936            (l, r)
937        })
938        .collect();
939
940    if on.is_empty() {
941        return Ok(vec![]);
942    }
943
944    let join = HashJoinExec::try_new(
945        left,
946        right,
947        on,
948        None,
949        &JoinType::LeftAnti,
950        None,
951        PartitionMode::CollectLeft,
952        datafusion::common::NullEquality::NullEqualsNull,
953    )?;
954
955    let join_arc: Arc<dyn ExecutionPlan> = Arc::new(join);
956    collect_all_partitions(&join_arc, task_ctx.clone()).await
957}
958
959// ---------------------------------------------------------------------------
960// Plan types for fixpoint rules
961// ---------------------------------------------------------------------------
962
/// IS-ref binding: a reference from a clause body to a derived relation.
#[derive(Debug, Clone)]
pub struct IsRefBinding {
    /// Index into the DerivedScanRegistry.
    pub derived_scan_index: usize,
    /// Name of the rule being referenced.
    pub rule_name: String,
    /// Whether this is a self-reference (rule references itself).
    pub is_self_ref: bool,
    /// Whether this is a negated reference (NOT IS).
    ///
    /// Negated refs are applied after the clause body executes: either as an
    /// anti-join (boolean exclusion) or as a probabilistic complement when the
    /// target rule carries a PROB column.
    pub negated: bool,
    /// For negated IS-refs: `(left_body_col, right_derived_col)` pairs for anti-join filtering.
    ///
    /// `left_body_col` is the VID column in the clause body (e.g., `"n._vid"`);
    /// `right_derived_col` is the corresponding KEY column in the negated rule's facts (e.g., `"n"`).
    /// Empty for non-negated IS-refs.
    pub anti_join_cols: Vec<(String, String)>,
    /// Whether the target rule has a PROB column.
    pub target_has_prob: bool,
    /// Name of the PROB column in the target rule, if any.
    ///
    /// When `target_has_prob` is true but this is `None`, negation falls back
    /// to a plain anti-join instead of the probabilistic complement.
    pub target_prob_col: Option<String>,
    /// `(body_col, derived_col)` pairs for provenance tracking.
    ///
    /// Used by shared-proof detection to find which source facts a derived row
    /// consumed. Populated for all IS-refs (not just negated ones).
    pub provenance_join_cols: Vec<(String, String)>,
}
990
/// A single clause (body) within a fixpoint rule.
#[derive(Debug)]
pub struct FixpointClausePlan {
    /// The logical plan for the clause body.
    ///
    /// Re-executed on every fixpoint iteration via `execute_subplan`.
    pub body_logical: LogicalPlan,
    /// IS-ref bindings used by this clause.
    pub is_ref_bindings: Vec<IsRefBinding>,
    /// Priority value for this clause (if PRIORITY semantics apply).
    pub priority: Option<i64>,
    /// ALONG binding variable names propagated from the planner.
    ///
    /// Matched against derived-fact columns when recording provenance so the
    /// ALONG values can be attached to `ProvenanceAnnotation`s.
    pub along_bindings: Vec<String>,
}
1003
/// Physical plan for a single rule in a fixpoint stratum.
#[derive(Debug)]
pub struct FixpointRulePlan {
    /// Rule name.
    pub name: String,
    /// Clause bodies (each evaluates to candidate rows).
    pub clauses: Vec<FixpointClausePlan>,
    /// Output schema for this rule's derived relation.
    pub yield_schema: SchemaRef,
    /// Indices of KEY columns within yield_schema.
    ///
    /// Also used to group rows for BEST BY merging and for shared-lineage
    /// detection over the pre-fold facts.
    pub key_column_indices: Vec<usize>,
    /// Priority value (if PRIORITY semantics apply).
    pub priority: Option<i64>,
    /// Whether this rule has FOLD semantics.
    pub has_fold: bool,
    /// FOLD bindings for post-fixpoint aggregation.
    pub fold_bindings: Vec<FoldBinding>,
    /// Post-FOLD filter expressions (HAVING semantics).
    pub having: Vec<Expr>,
    /// Whether this rule has BEST BY semantics.
    pub has_best_by: bool,
    /// BEST BY sort criteria for post-fixpoint selection.
    pub best_by_criteria: Vec<SortCriterion>,
    /// Whether this rule has PRIORITY semantics.
    pub has_priority: bool,
    /// Whether BEST BY should apply a deterministic secondary sort for
    /// tie-breaking. When false, tied rows are selected non-deterministically
    /// (faster but not repeatable across runs).
    pub deterministic: bool,
    /// Name of the PROB column in this rule's yield schema, if any.
    ///
    /// Complement factors produced by negated probabilistic IS-refs are
    /// multiplied into this column during clause evaluation.
    pub prob_column_name: Option<String>,
}
1036
1037// ---------------------------------------------------------------------------
1038// run_fixpoint_loop — the core semi-naive iteration algorithm
1039// ---------------------------------------------------------------------------
1040
/// Run the semi-naive fixpoint iteration loop.
///
/// Evaluates all rules in a stratum repeatedly, feeding deltas back through
/// derived scan handles until convergence or limits are reached.
///
/// Termination: the loop ends when (a) no rule produced new facts and every
/// per-rule state reports convergence, (b) `max_iterations` is exhausted, or
/// (c) `timeout` elapses. In cases (b) and (c) `timeout_flag` is set and the
/// partial results accumulated so far are still post-processed and returned.
///
/// Returns the concatenated post-fixpoint output of all rules; when no rule
/// produced output, a single empty batch with `output_schema` is returned.
#[expect(clippy::too_many_arguments, reason = "Fixpoint loop needs all context")]
async fn run_fixpoint_loop(
    rules: Vec<FixpointRulePlan>,
    max_iterations: usize,
    timeout: Duration,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    params: HashMap<String, Value>,
    registry: Arc<DerivedScanRegistry>,
    output_schema: SchemaRef,
    max_derived_bytes: usize,
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    top_k_proofs: usize,
    timeout_flag: Arc<std::sync::atomic::AtomicBool>,
) -> DFResult<Vec<RecordBatch>> {
    let start = Instant::now();
    let task_ctx = session_ctx.read().task_ctx();

    // Initialize per-rule state. Rules with FOLD bindings additionally get a
    // monotonic aggregate state built from those bindings.
    let mut states: Vec<FixpointState> = rules
        .iter()
        .map(|rule| {
            let monotonic_agg = if !rule.fold_bindings.is_empty() {
                let bindings: Vec<MonotonicFoldBinding> = rule
                    .fold_bindings
                    .iter()
                    .map(|fb| MonotonicFoldBinding {
                        fold_name: fb.output_name.clone(),
                        kind: fb.kind.clone(),
                        input_col_index: fb.input_col_index,
                        input_col_name: fb.input_col_name.clone(),
                    })
                    .collect();
                Some(MonotonicAggState::new(bindings))
            } else {
                None
            };
            FixpointState::new(
                rule.name.clone(),
                Arc::clone(&rule.yield_schema),
                rule.key_column_indices.clone(),
                max_derived_bytes,
                monotonic_agg,
                strict_probability_domain,
            )
        })
        .collect();

    // Main iteration loop
    let mut converged = false;
    let mut total_iters = 0usize;
    for iteration in 0..max_iterations {
        total_iters = iteration + 1;
        tracing::debug!("fixpoint iteration {}", iteration);
        let mut any_changed = false;

        for rule_idx in 0..rules.len() {
            let rule = &rules[rule_idx];

            // Update derived scan handles for this rule's clauses
            update_derived_scan_handles(&registry, &states, rule_idx, &rules);

            // Evaluate clause bodies, tracking per-clause candidates for provenance.
            let mut all_candidates = Vec::new();
            let mut clause_candidates: Vec<Vec<RecordBatch>> = Vec::new();
            for clause in &rule.clauses {
                let mut batches = execute_subplan(
                    &clause.body_logical,
                    &params,
                    &HashMap::new(),
                    &graph_ctx,
                    &session_ctx,
                    &storage,
                    &schema_info,
                )
                .await?;
                // Apply negated IS-ref semantics: probabilistic complement or anti-join.
                for binding in &clause.is_ref_bindings {
                    if binding.negated
                        && !binding.anti_join_cols.is_empty()
                        && let Some(entry) = registry.get(binding.derived_scan_index)
                    {
                        let neg_facts = entry.data.read().clone();
                        if !neg_facts.is_empty() {
                            if binding.target_has_prob && rule.prob_column_name.is_some() {
                                // Probabilistic complement: add 1-p column instead of filtering.
                                let complement_col =
                                    format!("__prob_complement_{}", binding.rule_name);
                                if let Some(prob_col) = &binding.target_prob_col {
                                    batches = apply_prob_complement_composite(
                                        batches,
                                        &neg_facts,
                                        &binding.anti_join_cols,
                                        prob_col,
                                        &complement_col,
                                    )?;
                                } else {
                                    // target_has_prob but no prob_col: fall back to anti-join.
                                    batches = apply_anti_join_composite(
                                        batches,
                                        &neg_facts,
                                        &binding.anti_join_cols,
                                    )?;
                                }
                            } else {
                                // Boolean exclusion: anti-join (existing behavior)
                                batches = apply_anti_join_composite(
                                    batches,
                                    &neg_facts,
                                    &binding.anti_join_cols,
                                )?;
                            }
                        }
                    }
                }
                // Multiply complement columns into the PROB column (if any) and clean up.
                // Complement columns are recognized by the "__prob_complement_" prefix
                // added above; only the first batch's schema is inspected since all
                // batches from one clause share a schema.
                let complement_cols: Vec<String> = if !batches.is_empty() {
                    batches[0]
                        .schema()
                        .fields()
                        .iter()
                        .filter(|f| f.name().starts_with("__prob_complement_"))
                        .map(|f| f.name().clone())
                        .collect()
                } else {
                    vec![]
                };
                if !complement_cols.is_empty() {
                    batches = multiply_prob_factors(
                        batches,
                        rule.prob_column_name.as_deref(),
                        &complement_cols,
                    )?;
                }

                clause_candidates.push(batches.clone());
                all_candidates.extend(batches);
            }

            // Merge candidates into facts.
            // For BEST BY rules, use a specialized merge that keeps only the
            // best row per KEY group, enabling convergence on cyclic graphs.
            let changed = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
                states[rule_idx].merge_best_by(all_candidates, &rule.best_by_criteria)?
            } else {
                states[rule_idx]
                    .merge_delta(all_candidates, Some(Arc::clone(&task_ctx)))
                    .await?
            };
            if changed {
                any_changed = true;
                // Record provenance for newly derived facts when tracker is present.
                if let Some(ref tracker) = derivation_tracker {
                    record_provenance(
                        tracker,
                        rule,
                        &states[rule_idx],
                        &clause_candidates,
                        iteration,
                        &registry,
                        top_k_proofs,
                    );
                }
            }
        }

        // Check convergence
        if !any_changed && states.iter().all(|s| s.is_converged()) {
            tracing::debug!("fixpoint converged after {} iterations", iteration + 1);
            converged = true;
            break;
        }

        // Check timeout. This runs once per full pass over all rules, so a
        // single slow pass can overshoot the budget before the flag is set.
        if start.elapsed() > timeout {
            tracing::warn!(
                "fixpoint timeout after {} iterations; returning partial results",
                iteration + 1,
            );
            timeout_flag.store(true, std::sync::atomic::Ordering::Relaxed);
            break;
        }
    }

    // Write per-rule iteration counts to the shared slot. A poisoned lock is
    // silently skipped: the counts are best-effort diagnostics only.
    if let Ok(mut counts) = iteration_counts.write() {
        for rule in &rules {
            counts.insert(rule.name.clone(), total_iters);
        }
    }

    // If we exhausted all iterations without converging, set timeout flag
    // and proceed with partial results rather than discarding all work.
    if !converged && !timeout_flag.load(std::sync::atomic::Ordering::Relaxed) {
        tracing::warn!(
            "fixpoint did not converge after {max_iterations} iterations; returning partial results",
        );
        timeout_flag.store(true, std::sync::atomic::Ordering::Relaxed);
    }

    // Post-fixpoint processing per rule and collect output.
    // Shared-lineage detection must run on the pre-fold facts, before
    // apply_post_fixpoint_chain collapses groups.
    let task_ctx = session_ctx.read().task_ctx();
    let mut all_output = Vec::new();

    for (rule_idx, state) in states.into_iter().enumerate() {
        let rule = &rules[rule_idx];
        let mut facts = state.into_facts();
        if facts.is_empty() {
            continue;
        }

        // Detect shared proofs before FOLD collapses groups.
        let shared_info = if let Some(ref tracker) = derivation_tracker {
            detect_shared_lineage(rule, &facts, tracker, &warnings_slot)
        } else {
            None
        };

        // Apply BDD for shared groups if exact_probability is enabled.
        if exact_probability
            && let Some(ref info) = shared_info
            && let Some(ref tracker) = derivation_tracker
        {
            facts = apply_exact_wmc(
                facts,
                rule,
                info,
                tracker,
                max_bdd_variables,
                &warnings_slot,
                &approximate_slot,
            )?;
        }

        let processed = apply_post_fixpoint_chain(
            facts,
            rule,
            &task_ctx,
            strict_probability_domain,
            probability_epsilon,
        )
        .await?;
        all_output.extend(processed);
    }

    // If no output, return empty batch with output schema
    if all_output.is_empty() {
        all_output.push(RecordBatch::new_empty(output_schema));
    }

    Ok(all_output)
}
1306
1307// ---------------------------------------------------------------------------
1308// Provenance recording helpers
1309// ---------------------------------------------------------------------------
1310
1311/// Record provenance for all newly derived facts (rows in the current delta).
1312///
1313/// Called after `merge_delta` returns `true`. Attributes each new fact to the
1314/// clause most likely to have produced it, using first-derivation-wins semantics.
1315fn record_provenance(
1316    tracker: &Arc<ProvenanceStore>,
1317    rule: &FixpointRulePlan,
1318    state: &FixpointState,
1319    clause_candidates: &[Vec<RecordBatch>],
1320    iteration: usize,
1321    registry: &Arc<DerivedScanRegistry>,
1322    top_k_proofs: usize,
1323) {
1324    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
1325
1326    // Pre-compute base fact probabilities for top-k mode.
1327    let base_probs = if top_k_proofs > 0 {
1328        tracker.base_fact_probs()
1329    } else {
1330        HashMap::new()
1331    };
1332
1333    for delta_batch in state.all_delta() {
1334        for row_idx in 0..delta_batch.num_rows() {
1335            let row_hash = format!(
1336                "{:?}",
1337                extract_scalar_key(delta_batch, &all_indices, row_idx)
1338            )
1339            .into_bytes();
1340            let fact_row = batch_row_to_value_map(delta_batch, row_idx);
1341            let clause_index =
1342                find_clause_for_row(delta_batch, row_idx, &all_indices, clause_candidates);
1343
1344            let support = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);
1345
1346            let proof_probability = if top_k_proofs > 0 {
1347                compute_proof_probability(&support, &base_probs)
1348            } else {
1349                None
1350            };
1351
1352            let entry = ProvenanceAnnotation {
1353                rule_name: rule.name.clone(),
1354                clause_index,
1355                support,
1356                along_values: {
1357                    let along_names: Vec<String> = rule
1358                        .clauses
1359                        .get(clause_index)
1360                        .map(|c| c.along_bindings.clone())
1361                        .unwrap_or_default();
1362                    along_names
1363                        .iter()
1364                        .filter_map(|name| fact_row.get(name).map(|v| (name.clone(), v.clone())))
1365                        .collect()
1366                },
1367                iteration,
1368                fact_row,
1369                proof_probability,
1370            };
1371            if top_k_proofs > 0 {
1372                tracker.record_top_k(row_hash, entry, top_k_proofs);
1373            } else {
1374                tracker.record(row_hash, entry);
1375            }
1376        }
1377    }
1378}
1379
1380/// Collect IS-ref input facts for a derived row using provenance join columns.
1381///
1382/// For each non-negated IS-ref binding in the clause, extracts body-side key
1383/// values from the delta row and finds matching source rows in the registry.
1384/// Returns a `ProofTerm` for each match (with the source fact hash).
1385fn collect_is_ref_inputs(
1386    rule: &FixpointRulePlan,
1387    clause_index: usize,
1388    delta_batch: &RecordBatch,
1389    row_idx: usize,
1390    registry: &Arc<DerivedScanRegistry>,
1391) -> Vec<ProofTerm> {
1392    let clause = match rule.clauses.get(clause_index) {
1393        Some(c) => c,
1394        None => return vec![],
1395    };
1396
1397    let mut inputs = Vec::new();
1398    let delta_schema = delta_batch.schema();
1399
1400    for binding in &clause.is_ref_bindings {
1401        if binding.negated {
1402            continue;
1403        }
1404        if binding.provenance_join_cols.is_empty() {
1405            continue;
1406        }
1407
1408        // Extract body-side values from the delta row for each provenance join col.
1409        let body_values: Vec<(String, ScalarKey)> = binding
1410            .provenance_join_cols
1411            .iter()
1412            .filter_map(|(body_col, _derived_col)| {
1413                let col_idx = delta_schema
1414                    .fields()
1415                    .iter()
1416                    .position(|f| f.name() == body_col)?;
1417                let key = extract_scalar_key(delta_batch, &[col_idx], row_idx);
1418                Some((body_col.clone(), key.into_iter().next()?))
1419            })
1420            .collect();
1421
1422        if body_values.len() != binding.provenance_join_cols.len() {
1423            continue;
1424        }
1425
1426        // Read current data from the registry entry for this IS-ref's rule.
1427        let entry = match registry.get(binding.derived_scan_index) {
1428            Some(e) => e,
1429            None => continue,
1430        };
1431        let source_batches = entry.data.read();
1432        let source_schema = &entry.schema;
1433
1434        // Find matching source rows and hash them.
1435        for src_batch in source_batches.iter() {
1436            let all_src_indices: Vec<usize> = (0..src_batch.num_columns()).collect();
1437            for src_row in 0..src_batch.num_rows() {
1438                let matches = binding.provenance_join_cols.iter().enumerate().all(
1439                    |(i, (_body_col, derived_col))| {
1440                        let src_col_idx = source_schema
1441                            .fields()
1442                            .iter()
1443                            .position(|f| f.name() == derived_col);
1444                        match src_col_idx {
1445                            Some(idx) => {
1446                                let src_key = extract_scalar_key(src_batch, &[idx], src_row);
1447                                src_key.first() == Some(&body_values[i].1)
1448                            }
1449                            None => false,
1450                        }
1451                    },
1452                );
1453                if matches {
1454                    let fact_hash = format!(
1455                        "{:?}",
1456                        extract_scalar_key(src_batch, &all_src_indices, src_row)
1457                    )
1458                    .into_bytes();
1459                    inputs.push(ProofTerm {
1460                        source_rule: binding.rule_name.clone(),
1461                        base_fact_id: fact_hash,
1462                    });
1463                }
1464            }
1465        }
1466    }
1467
1468    inputs
1469}
1470
1471// ---------------------------------------------------------------------------
1472// Shared-lineage detection
1473// ---------------------------------------------------------------------------
1474
// `detect_shared_lineage` (below) detects KEY groups in a rule's pre-fold
// facts where recursive derivation may violate the independence assumption of
// MNOR/MPROD.
//
// It uses a two-tier strategy:
// 1. **Precise**: If the `ProvenanceStore` has populated `support` for facts
//    in the group, we recursively compute lineage (Cui & Widom 2000) and
//    check for pairwise overlap. A shared base fact proves a dependency.
// 2. **Structural fallback**: When lineage tracking is unavailable (e.g., the
//    IS-ref subject variables were projected away), we check whether any fact
//    in a multi-row group was derived by a clause that has IS-ref bindings.
//    Recursive derivation through shared relations is a strong signal that
//    proof paths may share intermediate nodes.
/// Per-row data collected during shared-lineage detection.
#[expect(
    dead_code,
    reason = "Fields accessed via SharedLineageInfo in detect_shared_lineage"
)]
pub(crate) struct SharedGroupRow {
    /// Byte key uniquely identifying the fact row across all columns
    /// (see `fact_hash_key`).
    pub fact_hash: Vec<u8>,
    /// Hashes of the base facts this row's derivation depends on.
    pub lineage: HashSet<Vec<u8>>,
}
1497
/// Information about groups with shared proofs, returned by `detect_shared_lineage`.
pub(crate) struct SharedLineageInfo {
    /// KEY group → rows with their base fact sets.
    ///
    /// Keys are the rule's KEY column values (per `FixpointRulePlan::key_column_indices`).
    pub shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>>,
}
1503
1504/// Build a byte key that uniquely identifies a row across all columns.
1505fn fact_hash_key(batch: &RecordBatch, all_indices: &[usize], row_idx: usize) -> Vec<u8> {
1506    format!("{:?}", extract_scalar_key(batch, all_indices, row_idx)).into_bytes()
1507}
1508
/// Detect KEY groups in a rule's pre-fold facts where recursive derivation
/// may violate the independence assumption of MNOR/MPROD.
///
/// Two-tier strategy:
/// 1. **Precise**: when the `ProvenanceStore` has populated `support` for
///    facts in a group, lineage is computed recursively (Cui & Widom 2000)
///    and checked for pairwise overlap — a shared base fact proves a
///    dependency between proof paths.
/// 2. **Structural fallback**: when no tracked inputs exist, a multi-row
///    group is flagged if any of its facts was derived by a clause with
///    non-negated IS-ref bindings (recursive derivation).
///
/// Also checks whether a single IS-ref base fact feeds multiple KEY groups
/// and, if so, emits a `CrossGroupCorrelationNotExact` warning.
///
/// Emits at most one `SharedProbabilisticDependency` warning per rule.
/// Returns `Some(SharedLineageInfo)` if any group has shared proofs;
/// `None` for rules without MNOR/MPROD folds or when no sharing is found.
fn detect_shared_lineage(
    rule: &FixpointRulePlan,
    pre_fold_facts: &[RecordBatch],
    tracker: &Arc<ProvenanceStore>,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
) -> Option<SharedLineageInfo> {
    use crate::query::df_graph::locy_fold::FoldAggKind;
    use uni_locy::{RuntimeWarning, RuntimeWarningCode};

    // Only check rules with MNOR/MPROD fold bindings.
    let has_prob_fold = rule
        .fold_bindings
        .iter()
        .any(|fb| matches!(fb.kind, FoldAggKind::Nor | FoldAggKind::Prod));
    if !has_prob_fold {
        return None;
    }

    // Group facts by KEY columns.
    let key_indices = &rule.key_column_indices;
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    // KEY → hashes of all pre-fold rows belonging to that group.
    let mut groups: HashMap<Vec<ScalarKey>, Vec<Vec<u8>>> = HashMap::new();
    for batch in pre_fold_facts {
        for row_idx in 0..batch.num_rows() {
            let key = extract_scalar_key(batch, key_indices, row_idx);
            let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
            groups.entry(key).or_default().push(fact_hash);
        }
    }

    let mut shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>> = HashMap::new();
    let mut any_shared = false;

    // Check each group with ≥2 rows.
    // (Single-row groups cannot have within-group sharing.)
    for (key, fact_hashes) in &groups {
        if fact_hashes.len() < 2 {
            continue;
        }

        // Tier 1: precise base-fact overlap detection via tracker inputs.
        // A fresh `visited` set per row: lineage is per-derivation, and the
        // cycle guard must not suppress expansion across sibling rows.
        let mut has_inputs = false;
        let mut per_row_bases: Vec<HashSet<Vec<u8>>> = Vec::new();
        for fh in fact_hashes {
            let bases = compute_lineage(fh, tracker, &mut HashSet::new());
            if let Some(entry) = tracker.lookup(fh)
                && !entry.support.is_empty()
            {
                has_inputs = true;
            }
            per_row_bases.push(bases);
        }

        let shared_found = if has_inputs {
            // At least some facts have tracked inputs — do precise comparison.
            let mut found = false;
            'outer: for i in 0..per_row_bases.len() {
                for j in (i + 1)..per_row_bases.len() {
                    if !per_row_bases[i].is_disjoint(&per_row_bases[j]) {
                        found = true;
                        break 'outer;
                    }
                }
            }
            found
        } else {
            // Tier 2: structural fallback — check if any fact in the group was
            // derived by a clause with IS-ref bindings (recursive derivation).
            fact_hashes.iter().any(|fh| {
                tracker.lookup(fh).is_some_and(|entry| {
                    rule.clauses
                        .get(entry.clause_index)
                        .is_some_and(|clause| clause.is_ref_bindings.iter().any(|b| !b.negated))
                })
            })
        };

        if shared_found {
            any_shared = true;
            // Collect the group rows with their base facts for BDD use.
            let rows: Vec<SharedGroupRow> = fact_hashes
                .iter()
                .zip(per_row_bases.into_iter())
                .map(|(fh, bases)| SharedGroupRow {
                    fact_hash: fh.clone(),
                    lineage: bases,
                })
                .collect();
            shared_groups.insert(key.clone(), rows);
        }
    }

    // Phase 5: Cross-group correlation warning.
    // Check if any IS-ref input fact appears in multiple KEY groups.
    // This is independent of within-group sharing: even rules whose KEY groups
    // each have only one post-fold row can exhibit cross-group correlation when
    // different groups consume the same IS-ref base fact.
    {
        let mut input_to_groups: HashMap<Vec<u8>, HashSet<Vec<ScalarKey>>> = HashMap::new();
        for (key, fact_hashes) in &groups {
            for fh in fact_hashes {
                if let Some(entry) = tracker.lookup(fh) {
                    for input in &entry.support {
                        input_to_groups
                            .entry(input.base_fact_id.clone())
                            .or_default()
                            .insert(key.clone());
                    }
                }
            }
        }
        let has_cross_group = input_to_groups.values().any(|g| g.len() > 1);
        // Deduplicate: at most one cross-group warning per rule.
        if has_cross_group && let Ok(mut warnings) = warnings_slot.write() {
            let already_warned = warnings.iter().any(|w| {
                w.code == RuntimeWarningCode::CrossGroupCorrelationNotExact
                    && w.rule_name == rule.name
            });
            if !already_warned {
                warnings.push(RuntimeWarning {
                    code: RuntimeWarningCode::CrossGroupCorrelationNotExact,
                    message: format!(
                        "Rule '{}': IS-ref base facts are shared across different KEY \
                         groups. BDD corrects per-group probabilities but cannot account \
                         for cross-group correlations.",
                        rule.name
                    ),
                    rule_name: rule.name.clone(),
                    variable_count: None,
                    key_group: None,
                });
            }
        }
    }

    if any_shared {
        // Deduplicate: at most one shared-dependency warning per rule.
        if let Ok(mut warnings) = warnings_slot.write() {
            let already_warned = warnings.iter().any(|w| {
                w.code == RuntimeWarningCode::SharedProbabilisticDependency
                    && w.rule_name == rule.name
            });
            if !already_warned {
                warnings.push(RuntimeWarning {
                    code: RuntimeWarningCode::SharedProbabilisticDependency,
                    message: format!(
                        "Rule '{}' aggregates with MNOR/MPROD but some proof paths \
                         share intermediate facts, violating the independence assumption. \
                         Results may overestimate probability.",
                        rule.name
                    ),
                    rule_name: rule.name.clone(),
                    variable_count: None,
                    key_group: None,
                });
            }
        }
        Some(SharedLineageInfo { shared_groups })
    } else {
        None
    }
}
1671
1672/// Record provenance and detect shared proofs for non-recursive strata.
1673///
1674/// Non-recursive rules are evaluated in a single pass (no fixpoint loop), so
1675/// the regular `record_provenance` + `detect_shared_lineage` path is never hit.
1676/// This function bridges that gap by recording a `ProvenanceAnnotation` for every
1677/// fact produced by each clause and then running the same two-tier detection
1678/// logic used by the recursive path.
1679pub(crate) fn record_and_detect_lineage_nonrecursive(
1680    rule: &FixpointRulePlan,
1681    tagged_clause_facts: &[(usize, Vec<RecordBatch>)],
1682    tracker: &Arc<ProvenanceStore>,
1683    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
1684    registry: &Arc<DerivedScanRegistry>,
1685    top_k_proofs: usize,
1686) -> Option<SharedLineageInfo> {
1687    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
1688
1689    // Pre-compute base fact probabilities for top-k mode.
1690    let base_probs = if top_k_proofs > 0 {
1691        tracker.base_fact_probs()
1692    } else {
1693        HashMap::new()
1694    };
1695
1696    // Record provenance for each clause's facts.
1697    for (clause_index, batches) in tagged_clause_facts {
1698        for batch in batches {
1699            for row_idx in 0..batch.num_rows() {
1700                let row_hash = fact_hash_key(batch, &all_indices, row_idx);
1701                let fact_row = batch_row_to_value_map(batch, row_idx);
1702
1703                let support = collect_is_ref_inputs(rule, *clause_index, batch, row_idx, registry);
1704
1705                let proof_probability = if top_k_proofs > 0 {
1706                    compute_proof_probability(&support, &base_probs)
1707                } else {
1708                    None
1709                };
1710
1711                let entry = ProvenanceAnnotation {
1712                    rule_name: rule.name.clone(),
1713                    clause_index: *clause_index,
1714                    support,
1715                    along_values: {
1716                        let along_names: Vec<String> = rule
1717                            .clauses
1718                            .get(*clause_index)
1719                            .map(|c| c.along_bindings.clone())
1720                            .unwrap_or_default();
1721                        along_names
1722                            .iter()
1723                            .filter_map(|name| {
1724                                fact_row.get(name).map(|v| (name.clone(), v.clone()))
1725                            })
1726                            .collect()
1727                    },
1728                    iteration: 0,
1729                    fact_row,
1730                    proof_probability,
1731                };
1732                if top_k_proofs > 0 {
1733                    tracker.record_top_k(row_hash, entry, top_k_proofs);
1734                } else {
1735                    tracker.record(row_hash, entry);
1736                }
1737            }
1738        }
1739    }
1740
1741    // Flatten all clause facts and run detection.
1742    let all_facts: Vec<RecordBatch> = tagged_clause_facts
1743        .iter()
1744        .flat_map(|(_, batches)| batches.iter().cloned())
1745        .collect();
1746    detect_shared_lineage(rule, &all_facts, tracker, warnings_slot)
1747}
1748
/// Apply exact weighted model counting (WMC) for shared-lineage groups.
///
/// Replaces multiple rows in groups with shared lineage with a single
/// representative row whose PROB column carries the BDD-computed exact
/// probability (Sang et al. 2005). For groups that exceed
/// `max_bdd_variables`, rows are left unchanged and a `BddLimitExceeded`
/// warning is emitted.
///
/// Rows whose KEY is not in `shared_info` pass through untouched; batches
/// that end up with zero kept rows are dropped from the output.
pub(crate) fn apply_exact_wmc(
    pre_fold_facts: Vec<RecordBatch>,
    rule: &FixpointRulePlan,
    shared_info: &SharedLineageInfo,
    tracker: &Arc<ProvenanceStore>,
    max_bdd_variables: usize,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
    approximate_slot: &Arc<StdRwLock<HashMap<String, Vec<String>>>>,
) -> DFResult<Vec<RecordBatch>> {
    use crate::query::df_graph::locy_bdd::{SemiringOp, weighted_model_count};
    use crate::query::df_graph::locy_fold::FoldAggKind;
    use uni_locy::{RuntimeWarning, RuntimeWarningCode};

    // Find the MNOR/MPROD fold binding to know which column to overwrite.
    // Without one there is nothing to correct: return the input unchanged.
    let prob_fold = rule
        .fold_bindings
        .iter()
        .find(|fb| matches!(fb.kind, FoldAggKind::Nor | FoldAggKind::Prod));
    let prob_fold = match prob_fold {
        Some(f) => f,
        None => return Ok(pre_fold_facts),
    };
    // MNOR aggregates disjunctively (noisy-OR); MPROD conjunctively.
    let semiring_op = if matches!(prob_fold.kind, FoldAggKind::Nor) {
        SemiringOp::Disjunction
    } else {
        SemiringOp::Conjunction
    };
    let prob_col_idx = prob_fold.input_col_index;
    let prob_col_name = rule.yield_schema.field(prob_col_idx).name().clone();

    let key_indices = &rule.key_column_indices;
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    // Build a set of shared group keys for quick lookup.
    let shared_keys: HashSet<Vec<ScalarKey>> = shared_info.shared_groups.keys().cloned().collect();

    // Phase 1: Collect all rows for each shared KEY group across all batches.
    // Store (batch_index, row_index) pairs for each group.
    struct GroupAccum {
        /// Per-row lineage: the base-fact set of each row in the group.
        base_facts: Vec<HashSet<Vec<u8>>>,
        /// Base fact hash → its probability (looked up from the tracker).
        base_probs: HashMap<Vec<u8>, f64>,
        /// First occurrence: (batch_index, row_index) — used as representative.
        representative: (usize, usize),
        /// Every (batch_index, row_index) belonging to this group.
        row_locations: Vec<(usize, usize)>,
    }

    let mut group_accums: HashMap<Vec<ScalarKey>, GroupAccum> = HashMap::new();
    let mut non_shared_rows: Vec<(usize, usize)> = Vec::new(); // (batch_idx, row_idx)

    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
        for row_idx in 0..batch.num_rows() {
            let key = extract_scalar_key(batch, key_indices, row_idx);
            if shared_keys.contains(&key) {
                let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
                let bases = compute_lineage(&fact_hash, tracker, &mut HashSet::new());

                let accum = group_accums.entry(key).or_insert_with(|| GroupAccum {
                    base_facts: Vec::new(),
                    base_probs: HashMap::new(),
                    representative: (batch_idx, row_idx),
                    row_locations: Vec::new(),
                });

                // Look up probabilities for base facts.
                // Facts without a tracker entry / PROB value simply get no
                // weight here; `weighted_model_count` handles the absence.
                for bf in &bases {
                    if !accum.base_probs.contains_key(bf)
                        && let Some(entry) = tracker.lookup(bf)
                        && let Some(val) = entry.fact_row.get(&prob_col_name)
                        && let Some(p) = value_to_f64(val)
                    {
                        accum.base_probs.insert(bf.clone(), p);
                    }
                }

                accum.base_facts.push(bases);
                accum.row_locations.push((batch_idx, row_idx));
            } else {
                non_shared_rows.push((batch_idx, row_idx));
            }
        }
    }

    // Phase 2: Compute BDD for each shared group (across all batches).
    // Track which (batch_idx, row_idx) pairs to keep vs drop.
    let mut keep_rows: HashSet<(usize, usize)> = HashSet::new();
    // Map of (batch_idx, row_idx) → overridden PROB value (for BDD-succeeded groups).
    let mut overrides: HashMap<(usize, usize), f64> = HashMap::new();

    // All non-shared rows are kept.
    for &loc in &non_shared_rows {
        keep_rows.insert(loc);
    }

    for (key, accum) in &group_accums {
        let bdd_result = weighted_model_count(
            &accum.base_facts,
            &accum.base_probs,
            semiring_op,
            max_bdd_variables,
        );

        if bdd_result.approximated {
            // Emit BddLimitExceeded warning (one per key group).
            if let Ok(mut warnings) = warnings_slot.write() {
                let key_desc = format!("{key:?}");
                let already_warned = warnings.iter().any(|w| {
                    w.code == RuntimeWarningCode::BddLimitExceeded
                        && w.rule_name == rule.name
                        && w.key_group.as_deref() == Some(&key_desc)
                });
                if !already_warned {
                    warnings.push(RuntimeWarning {
                        code: RuntimeWarningCode::BddLimitExceeded,
                        message: format!(
                            "Rule '{}': BDD variable limit exceeded ({} > {}). \
                             Falling back to independence-mode result.",
                            rule.name, bdd_result.variable_count, max_bdd_variables
                        ),
                        rule_name: rule.name.clone(),
                        variable_count: Some(bdd_result.variable_count),
                        key_group: Some(key_desc),
                    });
                }
            }
            // Record the group as approximated so callers can surface it.
            if let Ok(mut approx) = approximate_slot.write() {
                let key_desc = format!("{key:?}");
                approx.entry(rule.name.clone()).or_default().push(key_desc);
            }
            // Keep all rows unchanged.
            for &loc in &accum.row_locations {
                keep_rows.insert(loc);
            }
        } else {
            // BDD succeeded: keep one representative row with overridden PROB.
            keep_rows.insert(accum.representative);
            overrides.insert(accum.representative, bdd_result.probability);
        }
    }

    // Phase 3: Build output batches by filtering kept rows per batch.
    let mut result_batches = Vec::new();
    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
        let kept_indices: Vec<usize> = (0..batch.num_rows())
            .filter(|&row_idx| keep_rows.contains(&(batch_idx, row_idx)))
            .collect();

        if kept_indices.is_empty() {
            continue;
        }

        let indices = arrow::array::UInt32Array::from(
            kept_indices.iter().map(|&i| i as u32).collect::<Vec<_>>(),
        );
        let mut columns: Vec<arrow::array::ArrayRef> = batch
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col, &indices, None))
            .collect::<Result<Vec<_>, _>>()
            .map_err(arrow_err)?;

        // Check if any kept rows have PROB overrides.
        let override_map: Vec<Option<f64>> = kept_indices
            .iter()
            .map(|&row_idx| overrides.get(&(batch_idx, row_idx)).copied())
            .collect();

        if override_map.iter().any(|o| o.is_some()) && prob_col_idx < columns.len() {
            // Rebuild the PROB column with overrides.
            // NOTE(review): if the PROB column is not Float64, non-overridden
            // rows fall back to 0.0 here — presumably unreachable for
            // MNOR/MPROD folds, but worth confirming.
            let existing_prob = columns[prob_col_idx]
                .as_any()
                .downcast_ref::<arrow::array::Float64Array>();
            let new_values: Vec<f64> = override_map
                .iter()
                .enumerate()
                .map(|(i, ov)| match ov {
                    Some(p) => *p,
                    None => existing_prob.map(|arr| arr.value(i)).unwrap_or(0.0),
                })
                .collect();
            columns[prob_col_idx] = Arc::new(arrow::array::Float64Array::from(new_values));
        }

        let result_batch = RecordBatch::try_new(batch.schema(), columns).map_err(arrow_err)?;
        result_batches.push(result_batch);
    }

    Ok(result_batches)
}
1944
1945/// Extract an f64 from a `Value`, supporting Float and Int.
1946fn value_to_f64(val: &uni_common::Value) -> Option<f64> {
1947    match val {
1948        uni_common::Value::Float(f) => Some(*f),
1949        uni_common::Value::Int(i) => Some(*i as f64),
1950        _ => None,
1951    }
1952}
1953
1954/// Compute the lineage of a derived fact (Cui & Widom 2000).
1955///
1956/// Recursively traverses the provenance store to collect the set of base-level
1957/// fact hashes that contribute to this derivation. A base fact is one with no
1958/// IS-ref support (a graph-level fact). Intermediate facts are expanded
1959/// transitively through the store.
1960fn compute_lineage(
1961    fact_hash: &[u8],
1962    tracker: &Arc<ProvenanceStore>,
1963    visited: &mut HashSet<Vec<u8>>,
1964) -> HashSet<Vec<u8>> {
1965    if !visited.insert(fact_hash.to_vec()) {
1966        return HashSet::new(); // Cycle guard.
1967    }
1968
1969    match tracker.lookup(fact_hash) {
1970        Some(entry) if !entry.support.is_empty() => {
1971            let mut bases = HashSet::new();
1972            for input in &entry.support {
1973                let child_bases = compute_lineage(&input.base_fact_id, tracker, visited);
1974                bases.extend(child_bases);
1975            }
1976            bases
1977        }
1978        _ => {
1979            // Base fact (no tracker entry or no inputs).
1980            let mut set = HashSet::new();
1981            set.insert(fact_hash.to_vec());
1982            set
1983        }
1984    }
1985}
1986
1987/// Determine which clause produced a given row by checking each clause's candidates.
1988///
1989/// Returns the index of the first clause whose candidates contain a matching row.
1990/// Falls back to 0 if no match is found.
1991fn find_clause_for_row(
1992    delta_batch: &RecordBatch,
1993    row_idx: usize,
1994    all_indices: &[usize],
1995    clause_candidates: &[Vec<RecordBatch>],
1996) -> usize {
1997    let target_key = extract_scalar_key(delta_batch, all_indices, row_idx);
1998    for (clause_idx, batches) in clause_candidates.iter().enumerate() {
1999        for batch in batches {
2000            if batch.num_columns() != all_indices.len() {
2001                continue;
2002            }
2003            for r in 0..batch.num_rows() {
2004                if extract_scalar_key(batch, all_indices, r) == target_key {
2005                    return clause_idx;
2006                }
2007            }
2008        }
2009    }
2010    0
2011}
2012
2013/// Convert a single row from a `RecordBatch` at `row_idx` into a `HashMap<String, Value>`.
2014fn batch_row_to_value_map(
2015    batch: &RecordBatch,
2016    row_idx: usize,
2017) -> std::collections::HashMap<String, Value> {
2018    use uni_store::storage::arrow_convert::arrow_to_value;
2019
2020    let schema = batch.schema();
2021    schema
2022        .fields()
2023        .iter()
2024        .enumerate()
2025        .map(|(col_idx, field)| {
2026            let col = batch.column(col_idx);
2027            let val = arrow_to_value(col.as_ref(), row_idx, None);
2028            (field.name().clone(), val)
2029        })
2030        .collect()
2031}
2032
2033/// Filter `batches` to exclude rows where `left_col` VID appears in `neg_facts[right_col]`.
2034///
2035/// Implements anti-join semantics for negated IS-refs (`n IS NOT rule`): keeps only
2036/// rows whose subject VID is NOT present in the negated rule's fully-converged facts.
2037pub fn apply_anti_join(
2038    batches: Vec<RecordBatch>,
2039    neg_facts: &[RecordBatch],
2040    left_col: &str,
2041    right_col: &str,
2042) -> datafusion::error::Result<Vec<RecordBatch>> {
2043    use arrow::compute::filter_record_batch;
2044    use arrow_array::{Array as _, BooleanArray, UInt64Array};
2045
2046    // Collect right-side VIDs from the negated rule's derived facts.
2047    let mut banned: std::collections::HashSet<u64> = std::collections::HashSet::new();
2048    for batch in neg_facts {
2049        let Ok(idx) = batch.schema().index_of(right_col) else {
2050            continue;
2051        };
2052        let arr = batch.column(idx);
2053        let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2054            continue;
2055        };
2056        for i in 0..vids.len() {
2057            if !vids.is_null(i) {
2058                banned.insert(vids.value(i));
2059            }
2060        }
2061    }
2062
2063    if banned.is_empty() {
2064        return Ok(batches);
2065    }
2066
2067    // Filter body batches: keep rows where left_col NOT IN banned.
2068    let mut result = Vec::new();
2069    for batch in batches {
2070        let Ok(idx) = batch.schema().index_of(left_col) else {
2071            result.push(batch);
2072            continue;
2073        };
2074        let arr = batch.column(idx);
2075        let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2076            result.push(batch);
2077            continue;
2078        };
2079        let keep: Vec<bool> = (0..vids.len())
2080            .map(|i| vids.is_null(i) || !banned.contains(&vids.value(i)))
2081            .collect();
2082        let keep_arr = BooleanArray::from(keep);
2083        let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
2084        if filtered.num_rows() > 0 {
2085            result.push(filtered);
2086        }
2087    }
2088    Ok(result)
2089}
2090
2091/// Probabilistic complement for negated IS-refs targeting PROB rules.
2092///
2093/// Instead of filtering out matching VIDs (anti-join), this adds a complement column
2094/// `__prob_complement_{rule_name}` with value `1 - p` for each matching VID, and `1.0`
2095/// for VIDs not present in the negated rule's facts.
2096///
2097/// This implements the probabilistic complement semantics: `IS NOT risk` on a PROB rule
2098/// yields the probability that the entity is NOT risky.
2099pub fn apply_prob_complement(
2100    batches: Vec<RecordBatch>,
2101    neg_facts: &[RecordBatch],
2102    left_col: &str,
2103    right_col: &str,
2104    prob_col: &str,
2105    complement_col_name: &str,
2106) -> datafusion::error::Result<Vec<RecordBatch>> {
2107    use arrow_array::{Array as _, Float64Array, UInt64Array};
2108
2109    // Build VID → probability lookup from negative facts
2110    let mut prob_map: std::collections::HashMap<u64, f64> = std::collections::HashMap::new();
2111    for batch in neg_facts {
2112        let Ok(vid_idx) = batch.schema().index_of(right_col) else {
2113            continue;
2114        };
2115        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
2116            continue;
2117        };
2118        let Some(vids) = batch.column(vid_idx).as_any().downcast_ref::<UInt64Array>() else {
2119            continue;
2120        };
2121        let prob_arr = batch.column(prob_idx);
2122        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
2123        for i in 0..vids.len() {
2124            if !vids.is_null(i) {
2125                let p = probs
2126                    .and_then(|arr| {
2127                        if arr.is_null(i) {
2128                            None
2129                        } else {
2130                            Some(arr.value(i))
2131                        }
2132                    })
2133                    .unwrap_or(0.0);
2134                // If multiple facts for same VID, use noisy-OR combination:
2135                // combined = 1 - (1 - existing) * (1 - new)
2136                prob_map
2137                    .entry(vids.value(i))
2138                    .and_modify(|existing| {
2139                        *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
2140                    })
2141                    .or_insert(p);
2142            }
2143        }
2144    }
2145
2146    // Add complement column to each batch
2147    let mut result = Vec::new();
2148    for batch in batches {
2149        let Ok(idx) = batch.schema().index_of(left_col) else {
2150            result.push(batch);
2151            continue;
2152        };
2153        let Some(vids) = batch.column(idx).as_any().downcast_ref::<UInt64Array>() else {
2154            result.push(batch);
2155            continue;
2156        };
2157
2158        // Compute complement values: 1 - p for matched VIDs, 1.0 for absent
2159        let complements: Vec<f64> = (0..vids.len())
2160            .map(|i| {
2161                if vids.is_null(i) {
2162                    1.0
2163                } else {
2164                    let p = prob_map.get(&vids.value(i)).copied().unwrap_or(0.0);
2165                    1.0 - p
2166                }
2167            })
2168            .collect();
2169
2170        let complement_arr = Float64Array::from(complements);
2171
2172        // Add the complement column to the batch
2173        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
2174        columns.push(std::sync::Arc::new(complement_arr));
2175
2176        let mut fields: Vec<std::sync::Arc<arrow_schema::Field>> =
2177            batch.schema().fields().iter().cloned().collect();
2178        fields.push(std::sync::Arc::new(arrow_schema::Field::new(
2179            complement_col_name,
2180            arrow_schema::DataType::Float64,
2181            true,
2182        )));
2183
2184        let new_schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
2185        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
2186        result.push(new_batch);
2187    }
2188    Ok(result)
2189}
2190
2191/// Probabilistic complement for composite (multi-column) join keys.
2192///
2193/// Builds a composite key from all `join_cols` right-side columns in
2194/// `neg_facts`, maps each composite key to a probability via noisy-OR
2195/// combination, then adds a single `complement_col_name` column with
2196/// `1 - p` for matched keys and `1.0` for absent keys.
2197pub fn apply_prob_complement_composite(
2198    batches: Vec<RecordBatch>,
2199    neg_facts: &[RecordBatch],
2200    join_cols: &[(String, String)],
2201    prob_col: &str,
2202    complement_col_name: &str,
2203) -> datafusion::error::Result<Vec<RecordBatch>> {
2204    use arrow_array::{Array as _, Float64Array, UInt64Array};
2205
2206    // Build composite-key → probability lookup from negative facts.
2207    let mut prob_map: HashMap<Vec<u64>, f64> = HashMap::new();
2208    for batch in neg_facts {
2209        let right_indices: Vec<usize> = join_cols
2210            .iter()
2211            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
2212            .collect();
2213        if right_indices.len() != join_cols.len() {
2214            continue;
2215        }
2216        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
2217            continue;
2218        };
2219        let prob_arr = batch.column(prob_idx);
2220        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
2221        for row in 0..batch.num_rows() {
2222            let mut key = Vec::with_capacity(right_indices.len());
2223            let mut valid = true;
2224            for &ci in &right_indices {
2225                let col = batch.column(ci);
2226                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
2227                    if vids.is_null(row) {
2228                        valid = false;
2229                        break;
2230                    }
2231                    key.push(vids.value(row));
2232                } else {
2233                    valid = false;
2234                    break;
2235                }
2236            }
2237            if !valid {
2238                continue;
2239            }
2240            let p = probs
2241                .and_then(|arr| {
2242                    if arr.is_null(row) {
2243                        None
2244                    } else {
2245                        Some(arr.value(row))
2246                    }
2247                })
2248                .unwrap_or(0.0);
2249            // Noisy-OR combination for duplicate composite keys.
2250            prob_map
2251                .entry(key)
2252                .and_modify(|existing| {
2253                    *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
2254                })
2255                .or_insert(p);
2256        }
2257    }
2258
2259    // Add complement column to each batch.
2260    let mut result = Vec::new();
2261    for batch in batches {
2262        let left_indices: Vec<usize> = join_cols
2263            .iter()
2264            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
2265            .collect();
2266        if left_indices.len() != join_cols.len() {
2267            result.push(batch);
2268            continue;
2269        }
2270        let all_u64 = left_indices.iter().all(|&ci| {
2271            batch
2272                .column(ci)
2273                .as_any()
2274                .downcast_ref::<UInt64Array>()
2275                .is_some()
2276        });
2277        if !all_u64 {
2278            result.push(batch);
2279            continue;
2280        }
2281
2282        let complements: Vec<f64> = (0..batch.num_rows())
2283            .map(|row| {
2284                let mut key = Vec::with_capacity(left_indices.len());
2285                for &ci in &left_indices {
2286                    let vids = batch
2287                        .column(ci)
2288                        .as_any()
2289                        .downcast_ref::<UInt64Array>()
2290                        .unwrap();
2291                    if vids.is_null(row) {
2292                        return 1.0;
2293                    }
2294                    key.push(vids.value(row));
2295                }
2296                let p = prob_map.get(&key).copied().unwrap_or(0.0);
2297                1.0 - p
2298            })
2299            .collect();
2300
2301        let complement_arr = Float64Array::from(complements);
2302        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
2303        columns.push(Arc::new(complement_arr));
2304
2305        let mut fields: Vec<Arc<arrow_schema::Field>> =
2306            batch.schema().fields().iter().cloned().collect();
2307        fields.push(Arc::new(arrow_schema::Field::new(
2308            complement_col_name,
2309            arrow_schema::DataType::Float64,
2310            true,
2311        )));
2312
2313        let new_schema = Arc::new(arrow_schema::Schema::new(fields));
2314        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
2315        result.push(new_batch);
2316    }
2317    Ok(result)
2318}
2319
2320/// Boolean anti-join for composite (multi-column) join keys.
2321///
2322/// Builds a `HashSet<Vec<u64>>` from `neg_facts` using all right-side
2323/// columns in `join_cols`, then filters `batches` to keep only rows
2324/// whose composite left-side key is NOT in the set.
2325pub fn apply_anti_join_composite(
2326    batches: Vec<RecordBatch>,
2327    neg_facts: &[RecordBatch],
2328    join_cols: &[(String, String)],
2329) -> datafusion::error::Result<Vec<RecordBatch>> {
2330    use arrow::compute::filter_record_batch;
2331    use arrow_array::{Array as _, BooleanArray, UInt64Array};
2332
2333    // Collect composite keys from the negated rule's derived facts.
2334    let mut banned: HashSet<Vec<u64>> = HashSet::new();
2335    for batch in neg_facts {
2336        let right_indices: Vec<usize> = join_cols
2337            .iter()
2338            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
2339            .collect();
2340        if right_indices.len() != join_cols.len() {
2341            continue;
2342        }
2343        for row in 0..batch.num_rows() {
2344            let mut key = Vec::with_capacity(right_indices.len());
2345            let mut valid = true;
2346            for &ci in &right_indices {
2347                let col = batch.column(ci);
2348                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
2349                    if vids.is_null(row) {
2350                        valid = false;
2351                        break;
2352                    }
2353                    key.push(vids.value(row));
2354                } else {
2355                    valid = false;
2356                    break;
2357                }
2358            }
2359            if valid {
2360                banned.insert(key);
2361            }
2362        }
2363    }
2364
2365    if banned.is_empty() {
2366        return Ok(batches);
2367    }
2368
2369    // Filter body batches: keep rows where composite left key NOT IN banned.
2370    let mut result = Vec::new();
2371    for batch in batches {
2372        let left_indices: Vec<usize> = join_cols
2373            .iter()
2374            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
2375            .collect();
2376        if left_indices.len() != join_cols.len() {
2377            result.push(batch);
2378            continue;
2379        }
2380        let all_u64 = left_indices.iter().all(|&ci| {
2381            batch
2382                .column(ci)
2383                .as_any()
2384                .downcast_ref::<UInt64Array>()
2385                .is_some()
2386        });
2387        if !all_u64 {
2388            result.push(batch);
2389            continue;
2390        }
2391
2392        let keep: Vec<bool> = (0..batch.num_rows())
2393            .map(|row| {
2394                let mut key = Vec::with_capacity(left_indices.len());
2395                for &ci in &left_indices {
2396                    let vids = batch
2397                        .column(ci)
2398                        .as_any()
2399                        .downcast_ref::<UInt64Array>()
2400                        .unwrap();
2401                    if vids.is_null(row) {
2402                        return true; // null keys are never banned
2403                    }
2404                    key.push(vids.value(row));
2405                }
2406                !banned.contains(&key)
2407            })
2408            .collect();
2409        let keep_arr = BooleanArray::from(keep);
2410        let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
2411        if filtered.num_rows() > 0 {
2412            result.push(filtered);
2413        }
2414    }
2415    Ok(result)
2416}
2417
2418/// Multiply `__prob_complement_*` columns into the rule's PROB column and clean up.
2419///
2420/// After IS NOT probabilistic complement semantics have added `__prob_complement_*`
2421/// columns to clause results, this function:
2422/// 1. Computes the product of all complement factor columns
2423/// 2. Multiplies the product into the existing PROB column (if any)
2424/// 3. Removes the internal `__prob_complement_*` columns from the output
2425///
2426/// If the rule has no PROB column, complement columns are simply removed
2427/// (the complement information is discarded and IS NOT acts as a keep-all).
2428pub fn multiply_prob_factors(
2429    batches: Vec<RecordBatch>,
2430    prob_col: Option<&str>,
2431    complement_cols: &[String],
2432) -> datafusion::error::Result<Vec<RecordBatch>> {
2433    use arrow_array::{Array as _, Float64Array};
2434
2435    let mut result = Vec::with_capacity(batches.len());
2436
2437    for batch in batches {
2438        if batch.num_rows() == 0 {
2439            // Remove complement columns from empty batches
2440            let keep: Vec<usize> = batch
2441                .schema()
2442                .fields()
2443                .iter()
2444                .enumerate()
2445                .filter(|(_, f)| !complement_cols.contains(f.name()))
2446                .map(|(i, _)| i)
2447                .collect();
2448            let fields: Vec<_> = keep
2449                .iter()
2450                .map(|&i| batch.schema().field(i).clone())
2451                .collect();
2452            let cols: Vec<_> = keep.iter().map(|&i| batch.column(i).clone()).collect();
2453            let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
2454            result.push(
2455                RecordBatch::try_new(schema, cols).map_err(|e| {
2456                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
2457                })?,
2458            );
2459            continue;
2460        }
2461
2462        let num_rows = batch.num_rows();
2463
2464        // 1. Compute product of all complement factors
2465        let mut combined = vec![1.0f64; num_rows];
2466        for col_name in complement_cols {
2467            if let Ok(idx) = batch.schema().index_of(col_name) {
2468                let arr = batch
2469                    .column(idx)
2470                    .as_any()
2471                    .downcast_ref::<Float64Array>()
2472                    .ok_or_else(|| {
2473                        datafusion::error::DataFusionError::Internal(format!(
2474                            "Expected Float64 for complement column {col_name}"
2475                        ))
2476                    })?;
2477                for (i, val) in combined.iter_mut().enumerate().take(num_rows) {
2478                    if !arr.is_null(i) {
2479                        *val *= arr.value(i);
2480                    }
2481                }
2482            }
2483        }
2484
2485        // 2. If there's a PROB column, multiply combined into it
2486        let final_prob: Vec<f64> = if let Some(prob_name) = prob_col {
2487            if let Ok(idx) = batch.schema().index_of(prob_name) {
2488                let arr = batch
2489                    .column(idx)
2490                    .as_any()
2491                    .downcast_ref::<Float64Array>()
2492                    .ok_or_else(|| {
2493                        datafusion::error::DataFusionError::Internal(format!(
2494                            "Expected Float64 for PROB column {prob_name}"
2495                        ))
2496                    })?;
2497                (0..num_rows)
2498                    .map(|i| {
2499                        if arr.is_null(i) {
2500                            combined[i]
2501                        } else {
2502                            arr.value(i) * combined[i]
2503                        }
2504                    })
2505                    .collect()
2506            } else {
2507                combined
2508            }
2509        } else {
2510            combined
2511        };
2512
2513        let new_prob_array: arrow_array::ArrayRef =
2514            std::sync::Arc::new(Float64Array::from(final_prob));
2515
2516        // 3. Build output: replace PROB column, remove complement columns
2517        let mut fields = Vec::new();
2518        let mut columns = Vec::new();
2519
2520        for (idx, field) in batch.schema().fields().iter().enumerate() {
2521            if complement_cols.contains(field.name()) {
2522                continue;
2523            }
2524            if prob_col.is_some_and(|p| field.name() == p) {
2525                fields.push(field.clone());
2526                columns.push(new_prob_array.clone());
2527            } else {
2528                fields.push(field.clone());
2529                columns.push(batch.column(idx).clone());
2530            }
2531        }
2532
2533        let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
2534        result.push(RecordBatch::try_new(schema, columns).map_err(arrow_err)?);
2535    }
2536
2537    Ok(result)
2538}
2539
2540/// Update derived scan handles before evaluating a rule's clause bodies.
2541///
2542/// For self-references: inject delta (semi-naive optimization).
2543/// For cross-references: inject full facts.
2544fn update_derived_scan_handles(
2545    registry: &DerivedScanRegistry,
2546    states: &[FixpointState],
2547    current_rule_idx: usize,
2548    rules: &[FixpointRulePlan],
2549) {
2550    let current_rule_name = &rules[current_rule_idx].name;
2551
2552    for entry in &registry.entries {
2553        // Find the state for this entry's rule
2554        let source_state_idx = rules.iter().position(|r| r.name == entry.rule_name);
2555        let Some(source_idx) = source_state_idx else {
2556            continue;
2557        };
2558
2559        let is_self = entry.rule_name == *current_rule_name;
2560        let data = if is_self {
2561            // Self-ref: inject delta for semi-naive
2562            states[source_idx].all_delta().to_vec()
2563        } else {
2564            // Cross-ref: inject full facts
2565            states[source_idx].all_facts().to_vec()
2566        };
2567
2568        // If empty, write an empty batch so the scan returns zero rows
2569        let data = if data.is_empty() || data.iter().all(|b| b.num_rows() == 0) {
2570            vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
2571        } else {
2572            data
2573        };
2574
2575        let mut guard = entry.data.write();
2576        *guard = data;
2577    }
2578}
2579
2580// ---------------------------------------------------------------------------
2581// DerivedScanExec — physical plan that reads from shared data at execution time
2582// ---------------------------------------------------------------------------
2583
/// Physical plan for `LocyDerivedScan` that reads from a shared `Arc<RwLock>` at
/// execution time (not at plan creation time).
///
/// This is critical for fixpoint iteration: the data handle is updated between
/// iterations, and each re-execution of the subplan must read the latest data.
pub struct DerivedScanExec {
    /// Shared handle whose contents are swapped between fixpoint iterations.
    data: Arc<RwLock<Vec<RecordBatch>>>,
    /// Declared output schema; also used to build an empty batch when the
    /// shared handle holds no data.
    schema: SchemaRef,
    /// Cached plan properties derived from `schema`.
    properties: PlanProperties,
}
2594
2595impl DerivedScanExec {
2596    pub fn new(data: Arc<RwLock<Vec<RecordBatch>>>, schema: SchemaRef) -> Self {
2597        let properties = compute_plan_properties(Arc::clone(&schema));
2598        Self {
2599            data,
2600            schema,
2601            properties,
2602        }
2603    }
2604}
2605
// Manual Debug: `data` is omitted — printing it would require taking the
// read lock and dumping the shared batches.
impl fmt::Debug for DerivedScanExec {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("DerivedScanExec")
            .field("schema", &self.schema)
            .finish()
    }
}
2613
impl DisplayAs for DerivedScanExec {
    /// One-line EXPLAIN rendering; the shared data is runtime state and is
    /// not shown.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "DerivedScanExec")
    }
}
2619
2620impl ExecutionPlan for DerivedScanExec {
2621    fn name(&self) -> &str {
2622        "DerivedScanExec"
2623    }
2624    fn as_any(&self) -> &dyn Any {
2625        self
2626    }
2627    fn schema(&self) -> SchemaRef {
2628        Arc::clone(&self.schema)
2629    }
2630    fn properties(&self) -> &PlanProperties {
2631        &self.properties
2632    }
2633    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
2634        vec![]
2635    }
2636    fn with_new_children(
2637        self: Arc<Self>,
2638        _children: Vec<Arc<dyn ExecutionPlan>>,
2639    ) -> DFResult<Arc<dyn ExecutionPlan>> {
2640        Ok(self)
2641    }
2642    fn execute(
2643        &self,
2644        _partition: usize,
2645        _context: Arc<TaskContext>,
2646    ) -> DFResult<SendableRecordBatchStream> {
2647        let batches = {
2648            let guard = self.data.read();
2649            if guard.is_empty() {
2650                vec![RecordBatch::new_empty(Arc::clone(&self.schema))]
2651            } else {
2652                guard.clone()
2653            }
2654        };
2655        Ok(Box::pin(MemoryStream::try_new(
2656            batches,
2657            Arc::clone(&self.schema),
2658            None,
2659        )?))
2660    }
2661}
2662
2663// ---------------------------------------------------------------------------
2664// InMemoryExec — wrapper to feed Vec<RecordBatch> into operator chains
2665// ---------------------------------------------------------------------------
2666
/// Simple in-memory execution plan that serves pre-computed batches.
///
/// Used internally to feed fixpoint results into post-fixpoint operator chains
/// (FOLD, BEST BY). Not exported — only used within this module.
struct InMemoryExec {
    /// Pre-computed batches, cloned into a stream on every `execute` call.
    batches: Vec<RecordBatch>,
    /// Declared output schema for the served batches.
    schema: SchemaRef,
    /// Cached plan properties derived from `schema`.
    properties: PlanProperties,
}
2676
2677impl InMemoryExec {
2678    fn new(batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
2679        let properties = compute_plan_properties(Arc::clone(&schema));
2680        Self {
2681            batches,
2682            schema,
2683            properties,
2684        }
2685    }
2686}
2687
// Manual Debug: batches are summarized by count rather than dumped in full.
impl fmt::Debug for InMemoryExec {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("InMemoryExec")
            .field("num_batches", &self.batches.len())
            .field("schema", &self.schema)
            .finish()
    }
}
2696
impl DisplayAs for InMemoryExec {
    /// One-line EXPLAIN rendering with the number of served batches.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "InMemoryExec: batches={}", self.batches.len())
    }
}
2702
2703impl ExecutionPlan for InMemoryExec {
2704    fn name(&self) -> &str {
2705        "InMemoryExec"
2706    }
2707    fn as_any(&self) -> &dyn Any {
2708        self
2709    }
2710    fn schema(&self) -> SchemaRef {
2711        Arc::clone(&self.schema)
2712    }
2713    fn properties(&self) -> &PlanProperties {
2714        &self.properties
2715    }
2716    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
2717        vec![]
2718    }
2719    fn with_new_children(
2720        self: Arc<Self>,
2721        _children: Vec<Arc<dyn ExecutionPlan>>,
2722    ) -> DFResult<Arc<dyn ExecutionPlan>> {
2723        Ok(self)
2724    }
2725    fn execute(
2726        &self,
2727        _partition: usize,
2728        _context: Arc<TaskContext>,
2729    ) -> DFResult<SendableRecordBatchStream> {
2730        Ok(Box::pin(MemoryStream::try_new(
2731            self.batches.clone(),
2732            Arc::clone(&self.schema),
2733            None,
2734        )?))
2735    }
2736}
2737
2738// ---------------------------------------------------------------------------
2739// Post-fixpoint chain — FOLD and BEST BY on converged facts
2740// ---------------------------------------------------------------------------
2741
/// Apply post-FOLD WHERE (HAVING) filter to aggregated batches.
///
/// Converts each Cypher HAVING expression to a DataFusion physical expression
/// via `cypher_expr_to_df` → type coercion → `create_physical_expr`, evaluates
/// against the FOLD output, and keeps only rows where all conditions hold.
///
/// Batches whose rows all fail the filter are dropped from the output; with an
/// empty `having_exprs` list every batch passes through unchanged.
fn apply_having_filter(
    batches: Vec<RecordBatch>,
    having_exprs: &[Expr],
    schema: &SchemaRef,
) -> DFResult<Vec<RecordBatch>> {
    use arrow::compute::{and, filter_record_batch};
    use arrow_array::BooleanArray;
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::LogicalPlanBuilder;
    use datafusion::optimizer::AnalyzerRule;
    use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
    use datafusion::physical_expr::create_physical_expr;
    use datafusion::prelude::SessionContext;

    // Fast path: nothing to filter, and no need to build physical expressions.
    if batches.is_empty() {
        return Ok(batches);
    }

    // Build DFSchema from the FOLD output Arrow schema.
    let df_schema = DFSchema::try_from(schema.as_ref().clone()).map_err(|e| {
        datafusion::common::DataFusionError::Internal(format!("HAVING schema conversion: {e}"))
    })?;

    // Throwaway session: only used as a source of config options and
    // execution props for expression planning below.
    let ctx = SessionContext::new();
    let state = ctx.state();
    let config = state.config_options().clone();
    let props = state.execution_props();

    // Cypher Expr → DataFusion DfExpr → type-coerced DfExpr → PhysicalExpr.
    //
    // Type coercion is needed because FOLD aggregates produce Float64 (SUM,
    // AVG) or Int64 (COUNT), and literal comparisons like `total >= 100`
    // may mix Float64 columns with Int64 literals.
    let physical_exprs: Vec<Arc<dyn datafusion::physical_expr::PhysicalExpr>> = having_exprs
        .iter()
        .map(|expr| {
            let df_expr = crate::query::df_expr::cypher_expr_to_df(expr, None).map_err(|e| {
                datafusion::common::DataFusionError::Internal(format!(
                    "HAVING expression conversion: {e}"
                ))
            })?;

            // Run DataFusion's type coercion by wrapping in a Filter plan,
            // applying the TypeCoercion analyzer rule, then extracting the
            // coerced predicate.
            let empty = datafusion::logical_expr::LogicalPlan::EmptyRelation(
                datafusion::logical_expr::EmptyRelation {
                    produce_one_row: false,
                    schema: Arc::new(df_schema.clone()),
                },
            );
            let filter_plan = LogicalPlanBuilder::from(empty)
                .filter(df_expr.clone())?
                .build()?;
            // If coercion errors or yields an unexpected plan shape, fall
            // back to the uncoerced expression; any residual type mismatch
            // then surfaces from `create_physical_expr` below.
            let coerced_expr = match TypeCoercion::new().analyze(filter_plan, &config) {
                Ok(datafusion::logical_expr::LogicalPlan::Filter(f)) => f.predicate,
                _ => df_expr,
            };

            create_physical_expr(&coerced_expr, &df_schema, props)
        })
        .collect::<DFResult<Vec<_>>>()?;

    let mut result = Vec::new();
    for batch in batches {
        // Evaluate each condition and AND the boolean masks.
        // `mask` stays `None` when `having_exprs` is empty, in which case
        // the batch is passed through unfiltered.
        let mut mask: Option<BooleanArray> = None;
        for phys_expr in &physical_exprs {
            let value = phys_expr.evaluate(&batch)?;
            let arr = value.into_array(batch.num_rows())?;
            let bool_arr = arr.as_any().downcast_ref::<BooleanArray>().ok_or_else(|| {
                datafusion::common::DataFusionError::Internal(
                    "HAVING condition must evaluate to boolean".into(),
                )
            })?;
            mask = Some(match mask {
                None => bool_arr.clone(),
                Some(prev) => and(&prev, bool_arr).map_err(arrow_err)?,
            });
        }
        if let Some(ref m) = mask {
            let filtered = filter_record_batch(&batch, m).map_err(arrow_err)?;
            // Batches left with zero rows are dropped entirely.
            if filtered.num_rows() > 0 {
                result.push(filtered);
            }
        } else {
            result.push(batch);
        }
    }
    Ok(result)
}
2838
/// Apply post-fixpoint operators (FOLD, HAVING, BEST BY, PRIORITY) to converged facts.
///
/// Stage order is PRIORITY → FOLD → HAVING → BEST BY; each stage is skipped
/// when the rule does not use it, and the converged `facts` are returned
/// untouched when none apply.
pub(crate) async fn apply_post_fixpoint_chain(
    facts: Vec<RecordBatch>,
    rule: &FixpointRulePlan,
    task_ctx: &Arc<TaskContext>,
    strict_probability_domain: bool,
    probability_epsilon: f64,
) -> DFResult<Vec<RecordBatch>> {
    // Fast path: no post-fixpoint operators on this rule.
    if !rule.has_fold && !rule.has_best_by && !rule.has_priority && rule.having.is_empty() {
        return Ok(facts);
    }

    // Wrap facts in InMemoryExec.
    // Prefer the actual batch schema (from physical execution) over the
    // pre-computed yield_schema, which may have wrong inferred types
    // (e.g. Float64 for a string property).
    let schema = facts
        .iter()
        .find(|b| b.num_rows() > 0)
        .map(|b| b.schema())
        .unwrap_or_else(|| Arc::clone(&rule.yield_schema));
    let input: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(facts, schema.clone()));

    // Reconcile key indices: rule's indices are yield-schema positions but
    // the actual batch may have different column ordering after schema
    // reconciliation during fixpoint iteration (same pattern as
    // FixpointState::reconcile_schema).
    // Key columns absent from the actual schema are silently dropped here.
    let key_column_indices: Vec<usize> = rule
        .key_column_indices
        .iter()
        .filter_map(|&i| {
            let name = rule.yield_schema.field(i).name();
            schema.index_of(name).ok()
        })
        .collect();

    // Apply PRIORITY first — keeps only rows with max __priority per KEY group,
    // then strips the __priority column from output.
    // Must run before FOLD so that the __priority column is still present.
    let current: Arc<dyn ExecutionPlan> = if rule.has_priority {
        let priority_schema = input.schema();
        let priority_idx = priority_schema.index_of("__priority").map_err(|_| {
            datafusion::common::DataFusionError::Internal(
                "PRIORITY rule missing __priority column".to_string(),
            )
        })?;
        Arc::new(PriorityExec::new(
            input,
            key_column_indices.clone(),
            priority_idx,
        ))
    } else {
        input
    };

    // Apply FOLD
    let current: Arc<dyn ExecutionPlan> = if rule.has_fold && !rule.fold_bindings.is_empty() {
        Arc::new(FoldExec::new(
            current,
            key_column_indices.clone(),
            rule.fold_bindings.clone(),
            strict_probability_domain,
            probability_epsilon,
        ))
    } else {
        current
    };

    // Apply HAVING (post-FOLD WHERE filter). This stage materializes the
    // chain built so far, filters the concrete batches, and re-wraps the
    // survivors for the next stage. An empty survivor set short-circuits.
    let current: Arc<dyn ExecutionPlan> = if !rule.having.is_empty() {
        let batches = collect_all_partitions(&current, Arc::clone(task_ctx)).await?;
        let filtered = apply_having_filter(batches, &rule.having, &current.schema())?;
        if filtered.is_empty() {
            return Ok(filtered);
        }
        Arc::new(InMemoryExec::new(filtered, Arc::clone(&current.schema())))
    } else {
        current
    };

    // Apply BEST BY
    let current: Arc<dyn ExecutionPlan> = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
        Arc::new(BestByExec::new(
            current,
            key_column_indices.clone(),
            rule.best_by_criteria.clone(),
            rule.deterministic,
        ))
    } else {
        current
    };

    // Materialize the fully assembled operator chain into batches.
    collect_all_partitions(&current, Arc::clone(task_ctx)).await
}
2933
2934// ---------------------------------------------------------------------------
2935// FixpointExec — DataFusion ExecutionPlan
2936// ---------------------------------------------------------------------------
2937
/// DataFusion `ExecutionPlan` that drives semi-naive fixpoint iteration.
///
/// Has no physical children: clause bodies are re-planned from logical plans
/// on each iteration (same pattern as `RecursiveCTEExec` and `GraphApplyExec`).
pub struct FixpointExec {
    /// Rule plans re-evaluated on each fixpoint iteration.
    rules: Vec<FixpointRulePlan>,
    /// Hard cap on the number of fixpoint iterations.
    max_iterations: usize,
    /// Wall-clock timeout for the fixpoint loop (see `timeout_flag`).
    timeout: Duration,
    /// Shared graph execution context for clause-body planning/execution.
    graph_ctx: Arc<GraphExecutionContext>,
    /// Shared DataFusion session used when re-planning clause bodies.
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    /// Storage backend handle.
    storage: Arc<StorageManager>,
    /// Uni schema metadata for the queried graph.
    schema_info: Arc<UniSchema>,
    /// Query parameter values by name.
    params: HashMap<String, Value>,
    /// Registry of derived-scan handles refreshed between iterations.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    /// Schema of the batches this operator emits.
    output_schema: SchemaRef,
    /// Cached plan properties derived from `output_schema`.
    properties: PlanProperties,
    /// DataFusion metrics registry; baseline metrics are created per partition.
    metrics: ExecutionPlanMetricsSet,
    /// Byte budget for derived facts — NOTE(review): enforcement happens in
    /// the iteration loop, outside this section; confirm there.
    max_derived_bytes: usize,
    /// Optional provenance tracker populated during fixpoint iteration.
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    /// Shared slot written with per-rule iteration counts after convergence.
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    /// Strict-domain flag for probabilities, threaded to the post-fixpoint
    /// FOLD stage — exact semantics are enforced in `FoldExec`.
    strict_probability_domain: bool,
    /// Probability epsilon, threaded to the post-fixpoint FOLD stage.
    probability_epsilon: f64,
    /// Exact (BDD-based) probability mode — NOTE(review): assumed from the
    /// name and `max_bdd_variables`; confirm in the iteration loop.
    exact_probability: bool,
    /// Cap on BDD variables before falling back (see `approximate_slot`).
    max_bdd_variables: usize,
    /// Shared slot for runtime warnings collected during fixpoint iteration.
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    /// Shared slot for groups where BDD fell back to independence mode.
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    /// When > 0, retain at most this many proofs per fact (top-k provenance).
    top_k_proofs: usize,
    /// Shared flag: set to true on timeout to signal partial results.
    timeout_flag: Arc<std::sync::atomic::AtomicBool>,
}
2973
// Manual Debug: only summary fields are printed; the many context handles
// (session, storage, shared slots, …) are elided via `finish_non_exhaustive`.
impl fmt::Debug for FixpointExec {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("FixpointExec")
            .field("rules_count", &self.rules.len())
            .field("max_iterations", &self.max_iterations)
            .field("timeout", &self.timeout)
            .field("output_schema", &self.output_schema)
            .field("max_derived_bytes", &self.max_derived_bytes)
            .finish_non_exhaustive()
    }
}
2985
2986impl FixpointExec {
2987    /// Create a new `FixpointExec`.
2988    #[expect(
2989        clippy::too_many_arguments,
2990        reason = "FixpointExec configuration needs all context"
2991    )]
2992    pub fn new(
2993        rules: Vec<FixpointRulePlan>,
2994        max_iterations: usize,
2995        timeout: Duration,
2996        graph_ctx: Arc<GraphExecutionContext>,
2997        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
2998        storage: Arc<StorageManager>,
2999        schema_info: Arc<UniSchema>,
3000        params: HashMap<String, Value>,
3001        derived_scan_registry: Arc<DerivedScanRegistry>,
3002        output_schema: SchemaRef,
3003        max_derived_bytes: usize,
3004        derivation_tracker: Option<Arc<ProvenanceStore>>,
3005        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
3006        strict_probability_domain: bool,
3007        probability_epsilon: f64,
3008        exact_probability: bool,
3009        max_bdd_variables: usize,
3010        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
3011        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
3012        top_k_proofs: usize,
3013        timeout_flag: Arc<std::sync::atomic::AtomicBool>,
3014    ) -> Self {
3015        let properties = compute_plan_properties(Arc::clone(&output_schema));
3016        Self {
3017            rules,
3018            max_iterations,
3019            timeout,
3020            graph_ctx,
3021            session_ctx,
3022            storage,
3023            schema_info,
3024            params,
3025            derived_scan_registry,
3026            output_schema,
3027            properties,
3028            metrics: ExecutionPlanMetricsSet::new(),
3029            max_derived_bytes,
3030            derivation_tracker,
3031            iteration_counts,
3032            strict_probability_domain,
3033            probability_epsilon,
3034            exact_probability,
3035            max_bdd_variables,
3036            warnings_slot,
3037            approximate_slot,
3038            top_k_proofs,
3039            timeout_flag,
3040        }
3041    }
3042
3043    /// Returns the shared iteration counts slot for post-execution inspection.
3044    pub fn iteration_counts(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
3045        Arc::clone(&self.iteration_counts)
3046    }
3047}
3048
3049impl DisplayAs for FixpointExec {
3050    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3051        write!(
3052            f,
3053            "FixpointExec: rules=[{}], max_iter={}, timeout={:?}",
3054            self.rules
3055                .iter()
3056                .map(|r| r.name.as_str())
3057                .collect::<Vec<_>>()
3058                .join(", "),
3059            self.max_iterations,
3060            self.timeout,
3061        )
3062    }
3063}
3064
impl ExecutionPlan for FixpointExec {
    fn name(&self) -> &str {
        "FixpointExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    // Schema of the stratum's result relation, as produced by the fixpoint loop.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.output_schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        // No physical children — clause bodies are re-planned each iteration
        vec![]
    }

    // Childless operator: reject any attempt to attach children; with an empty
    // child list there is nothing to rebuild, so return self unchanged.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Plan(
                "FixpointExec has no children".to_string(),
            ));
        }
        Ok(self)
    }

    // Launches the fixpoint loop as a boxed future and wraps it in a
    // `FixpointStream` that replays the accumulated result batches.
    //
    // Every field is cloned into the async closure because the future must be
    // `'static` (it may outlive `&self` once handed to the runtime).
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let metrics = BaselineMetrics::new(&self.metrics, partition);

        // Clone all fields for the async closure
        let rules = self
            .rules
            .iter()
            .map(|r| {
                // We need to clone the FixpointRulePlan, but it contains LogicalPlan
                // which doesn't implement Clone traditionally. However, our LogicalPlan
                // does implement Clone since it's an enum.
                FixpointRulePlan {
                    name: r.name.clone(),
                    clauses: r
                        .clauses
                        .iter()
                        .map(|c| FixpointClausePlan {
                            body_logical: c.body_logical.clone(),
                            is_ref_bindings: c.is_ref_bindings.clone(),
                            priority: c.priority,
                            along_bindings: c.along_bindings.clone(),
                        })
                        .collect(),
                    yield_schema: Arc::clone(&r.yield_schema),
                    key_column_indices: r.key_column_indices.clone(),
                    priority: r.priority,
                    has_fold: r.has_fold,
                    fold_bindings: r.fold_bindings.clone(),
                    having: r.having.clone(),
                    has_best_by: r.has_best_by,
                    best_by_criteria: r.best_by_criteria.clone(),
                    has_priority: r.has_priority,
                    deterministic: r.deterministic,
                    prob_column_name: r.prob_column_name.clone(),
                }
            })
            .collect();

        let max_iterations = self.max_iterations;
        let timeout = self.timeout;
        let graph_ctx = Arc::clone(&self.graph_ctx);
        let session_ctx = Arc::clone(&self.session_ctx);
        let storage = Arc::clone(&self.storage);
        let schema_info = Arc::clone(&self.schema_info);
        let params = self.params.clone();
        let registry = Arc::clone(&self.derived_scan_registry);
        let output_schema = Arc::clone(&self.output_schema);
        let max_derived_bytes = self.max_derived_bytes;
        let derivation_tracker = self.derivation_tracker.clone();
        let iteration_counts = Arc::clone(&self.iteration_counts);
        let strict_probability_domain = self.strict_probability_domain;
        let probability_epsilon = self.probability_epsilon;
        let exact_probability = self.exact_probability;
        let max_bdd_variables = self.max_bdd_variables;
        let warnings_slot = Arc::clone(&self.warnings_slot);
        let approximate_slot = Arc::clone(&self.approximate_slot);
        let top_k_proofs = self.top_k_proofs;
        let timeout_flag = Arc::clone(&self.timeout_flag);

        // NOTE: the argument order here is positional and must match the
        // `run_fixpoint_loop` signature exactly.
        let fut = async move {
            run_fixpoint_loop(
                rules,
                max_iterations,
                timeout,
                graph_ctx,
                session_ctx,
                storage,
                schema_info,
                params,
                registry,
                output_schema,
                max_derived_bytes,
                derivation_tracker,
                iteration_counts,
                strict_probability_domain,
                probability_epsilon,
                exact_probability,
                max_bdd_variables,
                warnings_slot,
                approximate_slot,
                top_k_proofs,
                timeout_flag,
            )
            .await
        };

        Ok(Box::pin(FixpointStream {
            state: FixpointStreamState::Running(Box::pin(fut)),
            schema: Arc::clone(&self.output_schema),
            metrics,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
3200
3201// ---------------------------------------------------------------------------
3202// FixpointStream — async state machine for streaming results
3203// ---------------------------------------------------------------------------
3204
/// Execution phase of a `FixpointStream`: Running → Emitting → Done.
enum FixpointStreamState {
    /// Fixpoint loop is running.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// Emitting accumulated result batches one at a time.
    /// The `usize` is the index of the next batch to emit.
    Emitting(Vec<RecordBatch>, usize),
    /// All batches emitted.
    Done,
}
3213
/// Stream adapter that first drives the fixpoint loop to completion, then
/// yields the accumulated result batches one by one.
struct FixpointStream {
    // Current phase of the state machine (Running → Emitting → Done).
    state: FixpointStreamState,
    // Output schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    // Baseline metrics; output rows are recorded as batches are emitted.
    metrics: BaselineMetrics,
}
3219
3220impl Stream for FixpointStream {
3221    type Item = DFResult<RecordBatch>;
3222
3223    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
3224        let this = self.get_mut();
3225        loop {
3226            match &mut this.state {
3227                FixpointStreamState::Running(fut) => match fut.as_mut().poll(cx) {
3228                    Poll::Ready(Ok(batches)) => {
3229                        if batches.is_empty() {
3230                            this.state = FixpointStreamState::Done;
3231                            return Poll::Ready(None);
3232                        }
3233                        this.state = FixpointStreamState::Emitting(batches, 0);
3234                        // Loop to emit first batch
3235                    }
3236                    Poll::Ready(Err(e)) => {
3237                        this.state = FixpointStreamState::Done;
3238                        return Poll::Ready(Some(Err(e)));
3239                    }
3240                    Poll::Pending => return Poll::Pending,
3241                },
3242                FixpointStreamState::Emitting(batches, idx) => {
3243                    if *idx >= batches.len() {
3244                        this.state = FixpointStreamState::Done;
3245                        return Poll::Ready(None);
3246                    }
3247                    let batch = batches[*idx].clone();
3248                    *idx += 1;
3249                    this.metrics.record_output(batch.num_rows());
3250                    return Poll::Ready(Some(Ok(batch)));
3251                }
3252                FixpointStreamState::Done => return Poll::Ready(None),
3253            }
3254        }
3255    }
3256}
3257
3258impl RecordBatchStream for FixpointStream {
3259    fn schema(&self) -> SchemaRef {
3260        Arc::clone(&self.schema)
3261    }
3262}
3263
3264// ---------------------------------------------------------------------------
3265// Unit tests
3266// ---------------------------------------------------------------------------
3267
3268#[cfg(test)]
3269mod tests {
3270    use super::*;
3271    use arrow_array::{Float64Array, Int64Array, StringArray};
3272    use arrow_schema::{DataType, Field, Schema};
3273
3274    fn test_schema() -> SchemaRef {
3275        Arc::new(Schema::new(vec![
3276            Field::new("name", DataType::Utf8, true),
3277            Field::new("value", DataType::Int64, true),
3278        ]))
3279    }
3280
3281    fn make_batch(names: &[&str], values: &[i64]) -> RecordBatch {
3282        RecordBatch::try_new(
3283            test_schema(),
3284            vec![
3285                Arc::new(StringArray::from(
3286                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
3287                )),
3288                Arc::new(Int64Array::from(values.to_vec())),
3289            ],
3290        )
3291        .unwrap()
3292    }
3293
3294    // --- FixpointState dedup tests ---
3295
3296    #[tokio::test]
3297    async fn test_fixpoint_state_empty_facts_adds_all() {
3298        let schema = test_schema();
3299        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
3300
3301        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
3302        let changed = state.merge_delta(vec![batch], None).await.unwrap();
3303
3304        assert!(changed);
3305        assert_eq!(state.all_facts().len(), 1);
3306        assert_eq!(state.all_facts()[0].num_rows(), 3);
3307        assert_eq!(state.all_delta().len(), 1);
3308        assert_eq!(state.all_delta()[0].num_rows(), 3);
3309    }
3310
3311    #[tokio::test]
3312    async fn test_fixpoint_state_exact_duplicates_excluded() {
3313        let schema = test_schema();
3314        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
3315
3316        let batch1 = make_batch(&["a", "b"], &[1, 2]);
3317        state.merge_delta(vec![batch1], None).await.unwrap();
3318
3319        // Same rows again
3320        let batch2 = make_batch(&["a", "b"], &[1, 2]);
3321        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
3322        assert!(!changed);
3323        assert!(
3324            state.all_delta().is_empty() || state.all_delta().iter().all(|b| b.num_rows() == 0)
3325        );
3326    }
3327
3328    #[tokio::test]
3329    async fn test_fixpoint_state_partial_overlap() {
3330        let schema = test_schema();
3331        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
3332
3333        let batch1 = make_batch(&["a", "b"], &[1, 2]);
3334        state.merge_delta(vec![batch1], None).await.unwrap();
3335
3336        // "a":1 is duplicate, "c":3 is new
3337        let batch2 = make_batch(&["a", "c"], &[1, 3]);
3338        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
3339        assert!(changed);
3340
3341        // Delta should have only "c":3
3342        let delta_rows: usize = state.all_delta().iter().map(|b| b.num_rows()).sum();
3343        assert_eq!(delta_rows, 1);
3344
3345        // Total facts: a:1, b:2, c:3
3346        let total_rows: usize = state.all_facts().iter().map(|b| b.num_rows()).sum();
3347        assert_eq!(total_rows, 3);
3348    }
3349
3350    #[tokio::test]
3351    async fn test_fixpoint_state_convergence() {
3352        let schema = test_schema();
3353        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
3354
3355        let batch = make_batch(&["a"], &[1]);
3356        state.merge_delta(vec![batch], None).await.unwrap();
3357
3358        // Empty candidates → converged
3359        let changed = state.merge_delta(vec![], None).await.unwrap();
3360        assert!(!changed);
3361        assert!(state.is_converged());
3362    }
3363
3364    // --- RowDedupState tests ---
3365
3366    #[test]
3367    fn test_row_dedup_persistent_across_calls() {
3368        // RowDedupState should remember rows from the first call so the second
3369        // call does not re-accept them (O(M) per iteration, no facts re-scan).
3370        let schema = test_schema();
3371        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
3372
3373        let batch1 = make_batch(&["a", "b"], &[1, 2]);
3374        let delta1 = rd.compute_delta(&[batch1], &schema).unwrap();
3375        // First call: both rows are new.
3376        let rows1: usize = delta1.iter().map(|b| b.num_rows()).sum();
3377        assert_eq!(rows1, 2);
3378
3379        // Second call with same rows: seen set already has them → empty delta.
3380        let batch2 = make_batch(&["a", "b"], &[1, 2]);
3381        let delta2 = rd.compute_delta(&[batch2], &schema).unwrap();
3382        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
3383        assert_eq!(rows2, 0);
3384
3385        // Third call with one old + one new: only the new row is returned.
3386        let batch3 = make_batch(&["a", "c"], &[1, 3]);
3387        let delta3 = rd.compute_delta(&[batch3], &schema).unwrap();
3388        let rows3: usize = delta3.iter().map(|b| b.num_rows()).sum();
3389        assert_eq!(rows3, 1);
3390    }
3391
3392    #[test]
3393    fn test_row_dedup_null_handling() {
3394        use arrow_array::StringArray;
3395        use arrow_schema::{DataType, Field, Schema};
3396
3397        let schema: SchemaRef = Arc::new(Schema::new(vec![
3398            Field::new("a", DataType::Utf8, true),
3399            Field::new("b", DataType::Int64, true),
3400        ]));
3401        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
3402
3403        // Two rows: (NULL, 1) and (NULL, 1) — same NULLs → duplicate.
3404        let batch_nulls = RecordBatch::try_new(
3405            Arc::clone(&schema),
3406            vec![
3407                Arc::new(StringArray::from(vec![None::<&str>, None::<&str>])),
3408                Arc::new(arrow_array::Int64Array::from(vec![1i64, 1i64])),
3409            ],
3410        )
3411        .unwrap();
3412        let delta = rd.compute_delta(&[batch_nulls], &schema).unwrap();
3413        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
3414        assert_eq!(rows, 1, "two identical NULL rows should be deduped to one");
3415
3416        // (NULL, 2) — NULL in same col but different non-null col → distinct.
3417        let batch_diff = RecordBatch::try_new(
3418            Arc::clone(&schema),
3419            vec![
3420                Arc::new(StringArray::from(vec![None::<&str>])),
3421                Arc::new(arrow_array::Int64Array::from(vec![2i64])),
3422            ],
3423        )
3424        .unwrap();
3425        let delta2 = rd.compute_delta(&[batch_diff], &schema).unwrap();
3426        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
3427        assert_eq!(rows2, 1, "(NULL, 2) is distinct from (NULL, 1)");
3428    }
3429
3430    #[test]
3431    fn test_row_dedup_within_candidate_dedup() {
3432        // Duplicate rows within a single candidate batch should be collapsed to one.
3433        let schema = test_schema();
3434        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
3435
3436        // Batch with three rows: a:1, a:1, b:2 — "a:1" appears twice.
3437        let batch = make_batch(&["a", "a", "b"], &[1, 1, 2]);
3438        let delta = rd.compute_delta(&[batch], &schema).unwrap();
3439        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
3440        assert_eq!(rows, 2, "within-batch dup should be collapsed: a:1, b:2");
3441    }
3442
3443    // --- Float rounding tests ---
3444
3445    #[test]
3446    fn test_round_float_columns_near_duplicates() {
3447        let schema = Arc::new(Schema::new(vec![
3448            Field::new("name", DataType::Utf8, true),
3449            Field::new("dist", DataType::Float64, true),
3450        ]));
3451        let batch = RecordBatch::try_new(
3452            Arc::clone(&schema),
3453            vec![
3454                Arc::new(StringArray::from(vec![Some("a"), Some("a")])),
3455                Arc::new(Float64Array::from(vec![1.0000000000001, 1.0000000000002])),
3456            ],
3457        )
3458        .unwrap();
3459
3460        let rounded = round_float_columns(&[batch]);
3461        assert_eq!(rounded.len(), 1);
3462        let col = rounded[0]
3463            .column(1)
3464            .as_any()
3465            .downcast_ref::<Float64Array>()
3466            .unwrap();
3467        // Both should round to same value
3468        assert_eq!(col.value(0), col.value(1));
3469    }
3470
3471    // --- DerivedScanRegistry tests ---
3472
3473    #[test]
3474    fn test_registry_write_read_round_trip() {
3475        let schema = test_schema();
3476        let data = Arc::new(RwLock::new(Vec::new()));
3477        let mut reg = DerivedScanRegistry::new();
3478        reg.add(DerivedScanEntry {
3479            scan_index: 0,
3480            rule_name: "reachable".into(),
3481            is_self_ref: true,
3482            data: Arc::clone(&data),
3483            schema: Arc::clone(&schema),
3484        });
3485
3486        let batch = make_batch(&["x"], &[42]);
3487        reg.write_data(0, vec![batch.clone()]);
3488
3489        let entry = reg.get(0).unwrap();
3490        let guard = entry.data.read();
3491        assert_eq!(guard.len(), 1);
3492        assert_eq!(guard[0].num_rows(), 1);
3493    }
3494
3495    #[test]
3496    fn test_registry_entries_for_rule() {
3497        let schema = test_schema();
3498        let mut reg = DerivedScanRegistry::new();
3499        reg.add(DerivedScanEntry {
3500            scan_index: 0,
3501            rule_name: "r1".into(),
3502            is_self_ref: true,
3503            data: Arc::new(RwLock::new(Vec::new())),
3504            schema: Arc::clone(&schema),
3505        });
3506        reg.add(DerivedScanEntry {
3507            scan_index: 1,
3508            rule_name: "r2".into(),
3509            is_self_ref: false,
3510            data: Arc::new(RwLock::new(Vec::new())),
3511            schema: Arc::clone(&schema),
3512        });
3513        reg.add(DerivedScanEntry {
3514            scan_index: 2,
3515            rule_name: "r1".into(),
3516            is_self_ref: false,
3517            data: Arc::new(RwLock::new(Vec::new())),
3518            schema: Arc::clone(&schema),
3519        });
3520
3521        assert_eq!(reg.entries_for_rule("r1").len(), 2);
3522        assert_eq!(reg.entries_for_rule("r2").len(), 1);
3523        assert_eq!(reg.entries_for_rule("r3").len(), 0);
3524    }
3525
3526    // --- MonotonicAggState tests ---
3527
3528    #[test]
3529    fn test_monotonic_agg_update_and_stability() {
3530        use crate::query::df_graph::locy_fold::FoldAggKind;
3531
3532        let bindings = vec![MonotonicFoldBinding {
3533            fold_name: "total".into(),
3534            kind: FoldAggKind::Sum,
3535            input_col_index: 1,
3536            input_col_name: None,
3537        }];
3538        let mut agg = MonotonicAggState::new(bindings);
3539
3540        // First update
3541        let batch = make_batch(&["a"], &[10]);
3542        agg.snapshot();
3543        let changed = agg.update(&[0], &[batch], false).unwrap();
3544        assert!(changed);
3545        assert!(!agg.is_stable()); // changed since snapshot
3546
3547        // Snapshot and check stability with no new data
3548        agg.snapshot();
3549        let changed = agg.update(&[0], &[], false).unwrap();
3550        assert!(!changed);
3551        assert!(agg.is_stable());
3552    }
3553
3554    // --- Memory limit test ---
3555
3556    #[tokio::test]
3557    async fn test_memory_limit_exceeded() {
3558        let schema = test_schema();
3559        // Set a tiny limit
3560        let mut state = FixpointState::new("test".into(), schema, vec![0], 1, None, false);
3561
3562        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
3563        let result = state.merge_delta(vec![batch], None).await;
3564        assert!(result.is_err());
3565        let err = result.unwrap_err().to_string();
3566        assert!(err.contains("memory limit"), "Error was: {}", err);
3567    }
3568
3569    // --- FixpointStream lifecycle test ---
3570
3571    #[tokio::test]
3572    async fn test_fixpoint_stream_emitting() {
3573        use futures::StreamExt;
3574
3575        let schema = test_schema();
3576        let batch1 = make_batch(&["a"], &[1]);
3577        let batch2 = make_batch(&["b"], &[2]);
3578
3579        let metrics = ExecutionPlanMetricsSet::new();
3580        let baseline = BaselineMetrics::new(&metrics, 0);
3581
3582        let mut stream = FixpointStream {
3583            state: FixpointStreamState::Emitting(vec![batch1, batch2], 0),
3584            schema,
3585            metrics: baseline,
3586        };
3587
3588        let stream = Pin::new(&mut stream);
3589        let batches: Vec<RecordBatch> = stream.filter_map(|r| async { r.ok() }).collect().await;
3590
3591        assert_eq!(batches.len(), 2);
3592        assert_eq!(batches[0].num_rows(), 1);
3593        assert_eq!(batches[1].num_rows(), 1);
3594    }
3595
3596    // ── MonotonicAggState MNOR/MPROD tests ──────────────────────────────
3597
3598    fn make_f64_batch(names: &[&str], values: &[f64]) -> RecordBatch {
3599        let schema = Arc::new(Schema::new(vec![
3600            Field::new("name", DataType::Utf8, true),
3601            Field::new("value", DataType::Float64, true),
3602        ]));
3603        RecordBatch::try_new(
3604            schema,
3605            vec![
3606                Arc::new(StringArray::from(
3607                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
3608                )),
3609                Arc::new(Float64Array::from(values.to_vec())),
3610            ],
3611        )
3612        .unwrap()
3613    }
3614
3615    fn make_nor_binding() -> Vec<MonotonicFoldBinding> {
3616        use crate::query::df_graph::locy_fold::FoldAggKind;
3617        vec![MonotonicFoldBinding {
3618            fold_name: "prob".into(),
3619            kind: FoldAggKind::Nor,
3620            input_col_index: 1,
3621            input_col_name: None,
3622        }]
3623    }
3624
3625    fn make_prod_binding() -> Vec<MonotonicFoldBinding> {
3626        use crate::query::df_graph::locy_fold::FoldAggKind;
3627        vec![MonotonicFoldBinding {
3628            fold_name: "prob".into(),
3629            kind: FoldAggKind::Prod,
3630            input_col_index: 1,
3631            input_col_name: None,
3632        }]
3633    }
3634
3635    fn acc_key(name: &str) -> (Vec<ScalarKey>, String) {
3636        (vec![ScalarKey::Utf8(name.to_string())], "prob".to_string())
3637    }
3638
3639    #[test]
3640    fn test_monotonic_nor_first_update() {
3641        let mut agg = MonotonicAggState::new(make_nor_binding());
3642        let batch = make_f64_batch(&["a"], &[0.3]);
3643        let changed = agg.update(&[0], &[batch], false).unwrap();
3644        assert!(changed);
3645        let val = agg.get_accumulator(&acc_key("a")).unwrap();
3646        assert!((val - 0.3).abs() < 1e-10, "expected 0.3, got {}", val);
3647    }
3648
3649    #[test]
3650    fn test_monotonic_nor_two_updates() {
3651        // Incremental NOR: acc = 1-(1-0.3)(1-0.5) = 0.65
3652        let mut agg = MonotonicAggState::new(make_nor_binding());
3653        let batch1 = make_f64_batch(&["a"], &[0.3]);
3654        agg.update(&[0], &[batch1], false).unwrap();
3655        let batch2 = make_f64_batch(&["a"], &[0.5]);
3656        agg.update(&[0], &[batch2], false).unwrap();
3657        let val = agg.get_accumulator(&acc_key("a")).unwrap();
3658        assert!((val - 0.65).abs() < 1e-10, "expected 0.65, got {}", val);
3659    }
3660
3661    #[test]
3662    fn test_monotonic_prod_first_update() {
3663        let mut agg = MonotonicAggState::new(make_prod_binding());
3664        let batch = make_f64_batch(&["a"], &[0.6]);
3665        let changed = agg.update(&[0], &[batch], false).unwrap();
3666        assert!(changed);
3667        let val = agg.get_accumulator(&acc_key("a")).unwrap();
3668        assert!((val - 0.6).abs() < 1e-10, "expected 0.6, got {}", val);
3669    }
3670
3671    #[test]
3672    fn test_monotonic_prod_two_updates() {
3673        // Incremental PROD: acc = 0.6 * 0.8 = 0.48
3674        let mut agg = MonotonicAggState::new(make_prod_binding());
3675        let batch1 = make_f64_batch(&["a"], &[0.6]);
3676        agg.update(&[0], &[batch1], false).unwrap();
3677        let batch2 = make_f64_batch(&["a"], &[0.8]);
3678        agg.update(&[0], &[batch2], false).unwrap();
3679        let val = agg.get_accumulator(&acc_key("a")).unwrap();
3680        assert!((val - 0.48).abs() < 1e-10, "expected 0.48, got {}", val);
3681    }
3682
3683    #[test]
3684    fn test_monotonic_nor_stability() {
3685        let mut agg = MonotonicAggState::new(make_nor_binding());
3686        let batch = make_f64_batch(&["a"], &[0.3]);
3687        agg.update(&[0], &[batch], false).unwrap();
3688        agg.snapshot();
3689        let changed = agg.update(&[0], &[], false).unwrap();
3690        assert!(!changed);
3691        assert!(agg.is_stable());
3692    }
3693
3694    #[test]
3695    fn test_monotonic_prod_stability() {
3696        let mut agg = MonotonicAggState::new(make_prod_binding());
3697        let batch = make_f64_batch(&["a"], &[0.6]);
3698        agg.update(&[0], &[batch], false).unwrap();
3699        agg.snapshot();
3700        let changed = agg.update(&[0], &[], false).unwrap();
3701        assert!(!changed);
3702        assert!(agg.is_stable());
3703    }
3704
3705    #[test]
3706    fn test_monotonic_nor_multi_group() {
3707        // (a,0.3),(b,0.5) then (a,0.5),(b,0.2) → a=0.65, b=0.6
3708        let mut agg = MonotonicAggState::new(make_nor_binding());
3709        let batch1 = make_f64_batch(&["a", "b"], &[0.3, 0.5]);
3710        agg.update(&[0], &[batch1], false).unwrap();
3711        let batch2 = make_f64_batch(&["a", "b"], &[0.5, 0.2]);
3712        agg.update(&[0], &[batch2], false).unwrap();
3713
3714        let val_a = agg.get_accumulator(&acc_key("a")).unwrap();
3715        let val_b = agg.get_accumulator(&acc_key("b")).unwrap();
3716        assert!(
3717            (val_a - 0.65).abs() < 1e-10,
3718            "expected a=0.65, got {}",
3719            val_a
3720        );
3721        assert!((val_b - 0.6).abs() < 1e-10, "expected b=0.6, got {}", val_b);
3722    }
3723
3724    #[test]
3725    fn test_monotonic_prod_zero_absorbing() {
3726        // Zero absorbs: once 0.0, all further updates stay 0.0
3727        let mut agg = MonotonicAggState::new(make_prod_binding());
3728        let batch1 = make_f64_batch(&["a"], &[0.5]);
3729        agg.update(&[0], &[batch1], false).unwrap();
3730        let batch2 = make_f64_batch(&["a"], &[0.0]);
3731        agg.update(&[0], &[batch2], false).unwrap();
3732
3733        let val = agg.get_accumulator(&acc_key("a")).unwrap();
3734        assert!((val - 0.0).abs() < 1e-10, "expected 0.0, got {}", val);
3735
3736        // Further updates don't change the absorbing zero
3737        agg.snapshot();
3738        let batch3 = make_f64_batch(&["a"], &[0.5]);
3739        let changed = agg.update(&[0], &[batch3], false).unwrap();
3740        assert!(!changed);
3741        assert!(agg.is_stable());
3742    }
3743
3744    #[test]
3745    fn test_monotonic_nor_clamping() {
3746        // 1.5 clamped to 1.0: acc = 1-(1-0)(1-1) = 1.0
3747        let mut agg = MonotonicAggState::new(make_nor_binding());
3748        let batch = make_f64_batch(&["a"], &[1.5]);
3749        agg.update(&[0], &[batch], false).unwrap();
3750        let val = agg.get_accumulator(&acc_key("a")).unwrap();
3751        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
3752    }
3753
3754    #[test]
3755    fn test_monotonic_nor_absorbing() {
3756        // p=1.0 absorbs: 0.3 then 1.0 → 1.0
3757        let mut agg = MonotonicAggState::new(make_nor_binding());
3758        let batch1 = make_f64_batch(&["a"], &[0.3]);
3759        agg.update(&[0], &[batch1], false).unwrap();
3760        let batch2 = make_f64_batch(&["a"], &[1.0]);
3761        agg.update(&[0], &[batch2], false).unwrap();
3762        let val = agg.get_accumulator(&acc_key("a")).unwrap();
3763        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
3764    }
3765
3766    // ── MonotonicAggState strict mode tests (Phase 5) ─────────────────────
3767
3768    #[test]
3769    fn test_monotonic_agg_strict_nor_rejects() {
3770        let mut agg = MonotonicAggState::new(make_nor_binding());
3771        let batch = make_f64_batch(&["a"], &[1.5]);
3772        let result = agg.update(&[0], &[batch], true);
3773        assert!(result.is_err());
3774        let err = result.unwrap_err().to_string();
3775        assert!(
3776            err.contains("strict_probability_domain"),
3777            "Expected strict error, got: {}",
3778            err
3779        );
3780    }
3781
3782    #[test]
3783    fn test_monotonic_agg_strict_prod_rejects() {
3784        let mut agg = MonotonicAggState::new(make_prod_binding());
3785        let batch = make_f64_batch(&["a"], &[2.0]);
3786        let result = agg.update(&[0], &[batch], true);
3787        assert!(result.is_err());
3788        let err = result.unwrap_err().to_string();
3789        assert!(
3790            err.contains("strict_probability_domain"),
3791            "Expected strict error, got: {}",
3792            err
3793        );
3794    }
3795
3796    #[test]
3797    fn test_monotonic_agg_strict_accepts_valid() {
3798        let mut agg = MonotonicAggState::new(make_nor_binding());
3799        let batch = make_f64_batch(&["a"], &[0.5]);
3800        let result = agg.update(&[0], &[batch], true);
3801        assert!(result.is_ok());
3802        let val = agg.get_accumulator(&acc_key("a")).unwrap();
3803        assert!((val - 0.5).abs() < 1e-10, "expected 0.5, got {}", val);
3804    }
3805
3806    // ── Complement function unit tests (Phase 4) ──────────────────────────
3807
3808    fn make_vid_prob_batch(vids: &[u64], probs: &[f64]) -> RecordBatch {
3809        use arrow_array::UInt64Array;
3810        let schema = Arc::new(Schema::new(vec![
3811            Field::new("vid", DataType::UInt64, true),
3812            Field::new("prob", DataType::Float64, true),
3813        ]));
3814        RecordBatch::try_new(
3815            schema,
3816            vec![
3817                Arc::new(UInt64Array::from(vids.to_vec())),
3818                Arc::new(Float64Array::from(probs.to_vec())),
3819            ],
3820        )
3821        .unwrap()
3822    }
3823
3824    #[test]
3825    fn test_prob_complement_basic() {
3826        // neg has VID=1 with prob=0.7 → complement=0.3; VID=2 absent → complement=1.0
3827        let body = make_vid_prob_batch(&[1, 2], &[0.9, 0.8]);
3828        let neg = make_vid_prob_batch(&[1], &[0.7]);
3829        let join_cols = vec![("vid".to_string(), "vid".to_string())];
3830        let result = apply_prob_complement_composite(
3831            vec![body],
3832            &[neg],
3833            &join_cols,
3834            "prob",
3835            "__complement_0",
3836        )
3837        .unwrap();
3838        assert_eq!(result.len(), 1);
3839        let batch = &result[0];
3840        let complement = batch
3841            .column_by_name("__complement_0")
3842            .unwrap()
3843            .as_any()
3844            .downcast_ref::<Float64Array>()
3845            .unwrap();
3846        // VID=1: complement = 1 - 0.7 = 0.3
3847        assert!(
3848            (complement.value(0) - 0.3).abs() < 1e-10,
3849            "expected 0.3, got {}",
3850            complement.value(0)
3851        );
3852        // VID=2: absent from neg → complement = 1.0
3853        assert!(
3854            (complement.value(1) - 1.0).abs() < 1e-10,
3855            "expected 1.0, got {}",
3856            complement.value(1)
3857        );
3858    }
3859
3860    #[test]
3861    fn test_prob_complement_noisy_or_duplicates() {
3862        // neg has VID=1 twice with prob=0.3 and prob=0.5
3863        // Combined via noisy-OR: 1-(1-0.3)(1-0.5) = 0.65
3864        // Complement = 1 - 0.65 = 0.35
3865        let body = make_vid_prob_batch(&[1], &[0.9]);
3866        let neg = make_vid_prob_batch(&[1, 1], &[0.3, 0.5]);
3867        let join_cols = vec![("vid".to_string(), "vid".to_string())];
3868        let result = apply_prob_complement_composite(
3869            vec![body],
3870            &[neg],
3871            &join_cols,
3872            "prob",
3873            "__complement_0",
3874        )
3875        .unwrap();
3876        let batch = &result[0];
3877        let complement = batch
3878            .column_by_name("__complement_0")
3879            .unwrap()
3880            .as_any()
3881            .downcast_ref::<Float64Array>()
3882            .unwrap();
3883        assert!(
3884            (complement.value(0) - 0.35).abs() < 1e-10,
3885            "expected 0.35, got {}",
3886            complement.value(0)
3887        );
3888    }
3889
3890    #[test]
3891    fn test_prob_complement_empty_neg() {
3892        // Empty neg_facts → body passes through with complement=1.0
3893        let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
3894        let join_cols = vec![("vid".to_string(), "vid".to_string())];
3895        let result =
3896            apply_prob_complement_composite(vec![body], &[], &join_cols, "prob", "__complement_0")
3897                .unwrap();
3898        let batch = &result[0];
3899        let complement = batch
3900            .column_by_name("__complement_0")
3901            .unwrap()
3902            .as_any()
3903            .downcast_ref::<Float64Array>()
3904            .unwrap();
3905        for i in 0..2 {
3906            assert!(
3907                (complement.value(i) - 1.0).abs() < 1e-10,
3908                "row {}: expected 1.0, got {}",
3909                i,
3910                complement.value(i)
3911            );
3912        }
3913    }
3914
3915    #[test]
3916    fn test_anti_join_basic() {
3917        // body [1,2,3], neg [2] → result [1,3]
3918        use arrow_array::UInt64Array;
3919        let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
3920        let neg = make_vid_prob_batch(&[2], &[0.0]);
3921        let join_cols = vec![("vid".to_string(), "vid".to_string())];
3922        let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
3923        assert_eq!(result.len(), 1);
3924        let batch = &result[0];
3925        assert_eq!(batch.num_rows(), 2);
3926        let vids = batch
3927            .column_by_name("vid")
3928            .unwrap()
3929            .as_any()
3930            .downcast_ref::<UInt64Array>()
3931            .unwrap();
3932        assert_eq!(vids.value(0), 1);
3933        assert_eq!(vids.value(1), 3);
3934    }
3935
3936    #[test]
3937    fn test_anti_join_empty_neg() {
3938        // Empty neg → all rows kept
3939        let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
3940        let join_cols = vec![("vid".to_string(), "vid".to_string())];
3941        let result = apply_anti_join_composite(vec![body], &[], &join_cols).unwrap();
3942        assert_eq!(result.len(), 1);
3943        assert_eq!(result[0].num_rows(), 3);
3944    }
3945
3946    #[test]
3947    fn test_anti_join_all_excluded() {
3948        // neg covers all body rows → empty result
3949        let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
3950        let neg = make_vid_prob_batch(&[1, 2], &[0.0, 0.0]);
3951        let join_cols = vec![("vid".to_string(), "vid".to_string())];
3952        let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
3953        let total: usize = result.iter().map(|b| b.num_rows()).sum();
3954        assert_eq!(total, 0);
3955    }
3956
3957    #[test]
3958    fn test_multiply_prob_single_complement() {
3959        // prob=0.8, complement=0.5 → output prob=0.4; complement col removed
3960        let body = make_vid_prob_batch(&[1], &[0.8]);
3961        // Add a complement column
3962        let complement_arr = Float64Array::from(vec![0.5]);
3963        let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
3964        cols.push(Arc::new(complement_arr));
3965        let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
3966        fields.push(Arc::new(Field::new(
3967            "__complement_0",
3968            DataType::Float64,
3969            true,
3970        )));
3971        let schema = Arc::new(Schema::new(fields));
3972        let batch = RecordBatch::try_new(schema, cols).unwrap();
3973
3974        let result =
3975            multiply_prob_factors(vec![batch], Some("prob"), &["__complement_0".to_string()])
3976                .unwrap();
3977        assert_eq!(result.len(), 1);
3978        let out = &result[0];
3979        // Complement column should be removed
3980        assert!(out.column_by_name("__complement_0").is_none());
3981        let prob = out
3982            .column_by_name("prob")
3983            .unwrap()
3984            .as_any()
3985            .downcast_ref::<Float64Array>()
3986            .unwrap();
3987        assert!(
3988            (prob.value(0) - 0.4).abs() < 1e-10,
3989            "expected 0.4, got {}",
3990            prob.value(0)
3991        );
3992    }
3993
3994    #[test]
3995    fn test_multiply_prob_multiple_complements() {
3996        // prob=0.8, c1=0.5, c2=0.6 → 0.8×0.5×0.6=0.24
3997        let body = make_vid_prob_batch(&[1], &[0.8]);
3998        let c1 = Float64Array::from(vec![0.5]);
3999        let c2 = Float64Array::from(vec![0.6]);
4000        let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
4001        cols.push(Arc::new(c1));
4002        cols.push(Arc::new(c2));
4003        let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
4004        fields.push(Arc::new(Field::new("__c1", DataType::Float64, true)));
4005        fields.push(Arc::new(Field::new("__c2", DataType::Float64, true)));
4006        let schema = Arc::new(Schema::new(fields));
4007        let batch = RecordBatch::try_new(schema, cols).unwrap();
4008
4009        let result = multiply_prob_factors(
4010            vec![batch],
4011            Some("prob"),
4012            &["__c1".to_string(), "__c2".to_string()],
4013        )
4014        .unwrap();
4015        let out = &result[0];
4016        assert!(out.column_by_name("__c1").is_none());
4017        assert!(out.column_by_name("__c2").is_none());
4018        let prob = out
4019            .column_by_name("prob")
4020            .unwrap()
4021            .as_any()
4022            .downcast_ref::<Float64Array>()
4023            .unwrap();
4024        assert!(
4025            (prob.value(0) - 0.24).abs() < 1e-10,
4026            "expected 0.24, got {}",
4027            prob.value(0)
4028        );
4029    }
4030
4031    #[test]
4032    fn test_multiply_prob_no_prob_column() {
4033        // No prob column → combined complements become the output
4034        use arrow_array::UInt64Array;
4035        let schema = Arc::new(Schema::new(vec![
4036            Field::new("vid", DataType::UInt64, true),
4037            Field::new("__c1", DataType::Float64, true),
4038        ]));
4039        let batch = RecordBatch::try_new(
4040            schema,
4041            vec![
4042                Arc::new(UInt64Array::from(vec![1u64])),
4043                Arc::new(Float64Array::from(vec![0.7])),
4044            ],
4045        )
4046        .unwrap();
4047
4048        let result = multiply_prob_factors(vec![batch], None, &["__c1".to_string()]).unwrap();
4049        let out = &result[0];
4050        // __c1 should be removed since it's a complement column
4051        assert!(out.column_by_name("__c1").is_none());
4052        // Only vid column remains
4053        assert_eq!(out.num_columns(), 1);
4054    }
4055}