// File: uni_query/query/df_graph/locy_fixpoint.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Fixpoint iteration operator for recursive Locy strata.
5//!
6//! `FixpointExec` drives semi-naive evaluation: it repeatedly evaluates the rules
7//! in a recursive stratum, feeding back deltas until no new facts are produced.
8
9use crate::query::df_graph::GraphExecutionContext;
10use crate::query::df_graph::common::{
11    ScalarKey, arrow_err, collect_all_partitions, compute_plan_properties, execute_subplan,
12    extract_scalar_key,
13};
14use crate::query::df_graph::locy_best_by::{BestByExec, SortCriterion};
15use crate::query::df_graph::locy_errors::LocyRuntimeError;
16use crate::query::df_graph::locy_explain::{
17    ProofTerm, ProvenanceAnnotation, ProvenanceStore, compute_proof_probability,
18};
19use crate::query::df_graph::locy_fold::{FoldBinding, FoldExec};
20use crate::query::df_graph::locy_priority::PriorityExec;
21use crate::query::planner::LogicalPlan;
22use arrow_array::RecordBatch;
23use arrow_row::{RowConverter, SortField};
24use arrow_schema::SchemaRef;
25use datafusion::common::JoinType;
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
29use datafusion::physical_plan::memory::MemoryStream;
30use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
31use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
32use futures::Stream;
33use parking_lot::RwLock;
34use std::any::Any;
35use std::collections::{HashMap, HashSet};
36use std::fmt;
37use std::pin::Pin;
38use std::sync::{Arc, RwLock as StdRwLock};
39use std::task::{Context, Poll};
40use std::time::{Duration, Instant};
41use uni_common::Value;
42use uni_common::core::schema::Schema as UniSchema;
43use uni_cypher::ast::Expr;
44use uni_locy::RuntimeWarning;
45use uni_store::storage::manager::StorageManager;
46
47// ---------------------------------------------------------------------------
48// DerivedScanRegistry — injection point for IS-ref data into subplans
49// ---------------------------------------------------------------------------
50
/// A single entry in the derived scan registry.
///
/// Each entry corresponds to one `LocyDerivedScan` node in the logical plan tree.
/// The `data` handle is shared with the logical plan node so that writing data here
/// makes it visible when the subplan is re-planned and executed.
#[derive(Debug)]
pub struct DerivedScanEntry {
    /// Index matching the `scan_index` in `LocyDerivedScan`.
    pub scan_index: usize,
    /// Name of the rule this scan reads from.
    pub rule_name: String,
    /// Whether this is a self-referential scan (rule references itself).
    pub is_self_ref: bool,
    /// Shared data handle — write batches here to inject into subplans.
    /// (`RwLock` here is `parking_lot::RwLock`, per the file's imports.)
    pub data: Arc<RwLock<Vec<RecordBatch>>>,
    /// Schema of the derived relation.
    pub schema: SchemaRef,
}
69
/// Registry of derived scan handles for fixpoint iteration.
///
/// During fixpoint, each clause body may reference derived relations via
/// `LocyDerivedScan` nodes. The registry maps scan indices to shared data
/// handles so the fixpoint loop can inject delta/full facts before each
/// iteration.
#[derive(Debug, Default)]
pub struct DerivedScanRegistry {
    // Flat list; lookups scan linearly by `scan_index` (see `get`).
    entries: Vec<DerivedScanEntry>,
}
80
81impl DerivedScanRegistry {
82    /// Create a new empty registry.
83    pub fn new() -> Self {
84        Self::default()
85    }
86
87    /// Add an entry to the registry.
88    pub fn add(&mut self, entry: DerivedScanEntry) {
89        self.entries.push(entry);
90    }
91
92    /// Get an entry by scan index.
93    pub fn get(&self, scan_index: usize) -> Option<&DerivedScanEntry> {
94        self.entries.iter().find(|e| e.scan_index == scan_index)
95    }
96
97    /// Write data into a scan entry's shared handle.
98    pub fn write_data(&self, scan_index: usize, batches: Vec<RecordBatch>) {
99        if let Some(entry) = self.get(scan_index) {
100            let mut guard = entry.data.write();
101            *guard = batches;
102        }
103    }
104
105    /// Get all entries for a given rule name.
106    pub fn entries_for_rule(&self, rule_name: &str) -> Vec<&DerivedScanEntry> {
107        self.entries
108            .iter()
109            .filter(|e| e.rule_name == rule_name)
110            .collect()
111    }
112}
113
114// ---------------------------------------------------------------------------
115// MonotonicAggState — tracking monotonic aggregates across iterations
116// ---------------------------------------------------------------------------
117
/// Monotonic aggregate binding: maps a fold name to its aggregate kind and column.
#[derive(Debug, Clone)]
pub struct MonotonicFoldBinding {
    /// Name of the fold this binding tracks; part of the accumulator map key.
    pub fold_name: String,
    /// Aggregate kind (e.g. Sum/Count/Max/Min/Nor/Prod — see `FoldAggKind`).
    pub kind: crate::query::df_graph::locy_fold::FoldAggKind,
    /// Positional index of the aggregate's input column (fallback resolution).
    pub input_col_index: usize,
    /// Column name for name-based resolution (more robust than positional index).
    pub input_col_name: Option<String>,
}
127
/// Tracks monotonic aggregate accumulators across fixpoint iterations.
///
/// After each iteration, accumulators are updated and compared to their previous
/// snapshot. The fixpoint has converged (w.r.t. aggregates) when all accumulators
/// are stable (no change between iterations).
#[derive(Debug)]
pub struct MonotonicAggState {
    /// Current accumulator values keyed by (group_key, fold_name).
    accumulators: HashMap<(Vec<ScalarKey>, String), f64>,
    /// Snapshot from the previous iteration for stability check.
    prev_snapshot: HashMap<(Vec<ScalarKey>, String), f64>,
    /// Bindings describing which aggregates to track.
    bindings: Vec<MonotonicFoldBinding>,
}
142
143impl MonotonicAggState {
144    /// Create a new monotonic aggregate state.
145    pub fn new(bindings: Vec<MonotonicFoldBinding>) -> Self {
146        Self {
147            accumulators: HashMap::new(),
148            prev_snapshot: HashMap::new(),
149            bindings,
150        }
151    }
152
153    /// Update accumulators with new delta batches.
154    ///
155    /// Returns `true` if any accumulator value changed. When `strict` is
156    /// `true`, Nor/Prod inputs outside `[0, 1]` produce an error instead
157    /// of being clamped.
158    pub fn update(
159        &mut self,
160        key_indices: &[usize],
161        delta_batches: &[RecordBatch],
162        strict: bool,
163    ) -> DFResult<bool> {
164        use crate::query::df_graph::locy_fold::FoldAggKind;
165
166        let mut changed = false;
167        for batch in delta_batches {
168            for row_idx in 0..batch.num_rows() {
169                let group_key = extract_scalar_key(batch, key_indices, row_idx);
170                for binding in &self.bindings {
171                    // Resolve input column by name first, fall back to index.
172                    let idx = binding
173                        .input_col_name
174                        .as_ref()
175                        .and_then(|name| batch.schema().index_of(name).ok())
176                        .unwrap_or(binding.input_col_index);
177                    if idx >= batch.num_columns() {
178                        continue;
179                    }
180                    let col = batch.column(idx);
181                    let val = extract_f64(col.as_ref(), row_idx);
182                    if let Some(val) = val {
183                        let map_key = (group_key.clone(), binding.fold_name.clone());
184                        let entry = self
185                            .accumulators
186                            .entry(map_key)
187                            .or_insert(binding.kind.identity().unwrap_or(0.0));
188                        let old = *entry;
189                        match binding.kind {
190                            FoldAggKind::Sum | FoldAggKind::Count => *entry += val,
191                            FoldAggKind::Max => {
192                                if val > *entry {
193                                    *entry = val;
194                                }
195                            }
196                            FoldAggKind::Min => {
197                                if val < *entry {
198                                    *entry = val;
199                                }
200                            }
201                            FoldAggKind::Nor => {
202                                if strict && !(0.0..=1.0).contains(&val) {
203                                    return Err(datafusion::error::DataFusionError::Execution(
204                                        format!(
205                                            "strict_probability_domain: MNOR input {val} is outside [0, 1]"
206                                        ),
207                                    ));
208                                }
209                                if !strict && !(0.0..=1.0).contains(&val) {
210                                    tracing::warn!(
211                                        "MNOR input {val} outside [0,1], clamped to {}",
212                                        val.clamp(0.0, 1.0)
213                                    );
214                                }
215                                let p = val.clamp(0.0, 1.0);
216                                *entry = 1.0 - (1.0 - *entry) * (1.0 - p);
217                            }
218                            FoldAggKind::Prod => {
219                                if strict && !(0.0..=1.0).contains(&val) {
220                                    return Err(datafusion::error::DataFusionError::Execution(
221                                        format!(
222                                            "strict_probability_domain: MPROD input {val} is outside [0, 1]"
223                                        ),
224                                    ));
225                                }
226                                if !strict && !(0.0..=1.0).contains(&val) {
227                                    tracing::warn!(
228                                        "MPROD input {val} outside [0,1], clamped to {}",
229                                        val.clamp(0.0, 1.0)
230                                    );
231                                }
232                                let p = val.clamp(0.0, 1.0);
233                                *entry *= p;
234                            }
235                            _ => {}
236                        }
237                        if (*entry - old).abs() > f64::EPSILON {
238                            changed = true;
239                        }
240                    }
241                }
242            }
243        }
244        Ok(changed)
245    }
246
247    /// Take a snapshot of current accumulators for stability comparison.
248    pub fn snapshot(&mut self) {
249        self.prev_snapshot = self.accumulators.clone();
250    }
251
252    /// Check if accumulators are stable (no change since last snapshot).
253    pub fn is_stable(&self) -> bool {
254        if self.accumulators.len() != self.prev_snapshot.len() {
255            return false;
256        }
257        for (key, val) in &self.accumulators {
258            match self.prev_snapshot.get(key) {
259                Some(prev) if (*val - *prev).abs() <= f64::EPSILON => {}
260                _ => return false,
261            }
262        }
263        true
264    }
265
266    /// Test-only accessor for accumulator values.
267    #[cfg(test)]
268    pub(crate) fn get_accumulator(&self, key: &(Vec<ScalarKey>, String)) -> Option<f64> {
269        self.accumulators.get(key).copied()
270    }
271}
272
273/// Extract f64 value from an Arrow column at a given row index.
274fn extract_f64(col: &dyn arrow_array::Array, row_idx: usize) -> Option<f64> {
275    if col.is_null(row_idx) {
276        return None;
277    }
278    if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Float64Array>() {
279        Some(arr.value(row_idx))
280    } else {
281        col.as_any()
282            .downcast_ref::<arrow_array::Int64Array>()
283            .map(|arr| arr.value(row_idx) as f64)
284    }
285}
286
287// ---------------------------------------------------------------------------
288// RowDedupState — Arrow RowConverter-based persistent dedup set
289// ---------------------------------------------------------------------------
290
/// Arrow-native row deduplication using [`RowConverter`].
///
/// Unlike the legacy `HashSet<Vec<ScalarKey>>` approach, this struct maintains a
/// persistent `seen` set across iterations so per-iteration cost is O(M) where M
/// is the number of candidate rows — the full facts table is never re-scanned.
struct RowDedupState {
    /// Encodes a full row's columns into comparable/hashable bytes.
    converter: RowConverter,
    /// Byte-encoded rows accepted so far; persists across fixpoint iterations.
    seen: HashSet<Box<[u8]>>,
}
300
301impl RowDedupState {
302    /// Try to build a `RowDedupState` for the given schema.
303    ///
304    /// Returns `None` if any column type is not supported by `RowConverter`
305    /// (triggers legacy fallback).
306    fn try_new(schema: &SchemaRef) -> Option<Self> {
307        let fields: Vec<SortField> = schema
308            .fields()
309            .iter()
310            .map(|f| SortField::new(f.data_type().clone()))
311            .collect();
312        match RowConverter::new(fields) {
313            Ok(converter) => Some(Self {
314                converter,
315                seen: HashSet::new(),
316            }),
317            Err(e) => {
318                tracing::warn!(
319                    "RowDedupState: RowConverter unsupported for schema, falling back to legacy dedup: {}",
320                    e
321                );
322                None
323            }
324        }
325    }
326
327    /// Populate the seen set from existing fact batches.
328    ///
329    /// Used after BEST BY in-loop pruning replaces the fact set, so that delta
330    /// computation in subsequent iterations correctly recognizes surviving facts.
331    fn ingest_existing(&mut self, facts: &[RecordBatch], _schema: &SchemaRef) {
332        self.seen.clear();
333        for batch in facts {
334            if batch.num_rows() == 0 {
335                continue;
336            }
337            let arrays: Vec<_> = batch.columns().to_vec();
338            if let Ok(rows) = self.converter.convert_columns(&arrays) {
339                for row_idx in 0..batch.num_rows() {
340                    let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
341                    self.seen.insert(row_bytes);
342                }
343            }
344        }
345    }
346
347    /// Filter `candidates` to only rows not yet seen, updating the persistent set.
348    ///
349    /// Both cross-iteration dedup (rows already accepted in prior iterations) and
350    /// within-batch dedup (duplicate rows in a single candidate batch) are handled
351    /// in a single pass.
352    fn compute_delta(
353        &mut self,
354        candidates: &[RecordBatch],
355        schema: &SchemaRef,
356    ) -> DFResult<Vec<RecordBatch>> {
357        let mut delta_batches = Vec::new();
358        for batch in candidates {
359            if batch.num_rows() == 0 {
360                continue;
361            }
362
363            // Vectorized encoding of all rows in this batch.
364            let arrays: Vec<_> = batch.columns().to_vec();
365            let rows = self.converter.convert_columns(&arrays).map_err(arrow_err)?;
366
367            // One pass: check+insert into persistent seen set.
368            let mut keep = Vec::with_capacity(batch.num_rows());
369            for row_idx in 0..batch.num_rows() {
370                let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
371                keep.push(self.seen.insert(row_bytes));
372            }
373
374            let keep_mask = arrow_array::BooleanArray::from(keep);
375            let new_cols = batch
376                .columns()
377                .iter()
378                .map(|col| {
379                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
380                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
381                    })
382                })
383                .collect::<DFResult<Vec<_>>>()?;
384
385            if new_cols.first().is_some_and(|c| !c.is_empty()) {
386                let filtered = RecordBatch::try_new(Arc::clone(schema), new_cols).map_err(|e| {
387                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
388                })?;
389                delta_batches.push(filtered);
390            }
391        }
392        Ok(delta_batches)
393    }
394}
395
396// ---------------------------------------------------------------------------
397// FixpointState — per-rule delta tracking during fixpoint iteration
398// ---------------------------------------------------------------------------
399
/// Per-rule state for fixpoint iteration.
///
/// Tracks accumulated facts and the delta (new facts from the latest iteration).
/// Deduplication uses Arrow [`RowConverter`] with a persistent seen set (O(M) per
/// iteration) when supported, with a legacy `HashSet<Vec<ScalarKey>>` fallback.
pub struct FixpointState {
    // Rule this state belongs to (used in error messages and logs).
    rule_name: String,
    // All facts accepted so far, across every completed iteration.
    facts: Vec<RecordBatch>,
    // New facts produced by the most recent merge; cleared when empty.
    delta: Vec<RecordBatch>,
    // Output schema; may be replaced by `reconcile_schema` once the first
    // real batch arrives with authoritative physical types.
    schema: SchemaRef,
    key_column_indices: Vec<usize>,
    /// KEY column names for recomputing indices after schema reconciliation.
    key_column_names: Vec<String>,
    /// All column indices for full-row dedup (legacy path only).
    all_column_indices: Vec<usize>,
    /// Running total of facts bytes for memory limit tracking.
    facts_bytes: usize,
    /// Maximum bytes allowed for this derived relation.
    max_derived_bytes: usize,
    /// Optional monotonic aggregate tracking.
    monotonic_agg: Option<MonotonicAggState>,
    /// Arrow RowConverter-based dedup state; `None` triggers legacy fallback.
    row_dedup: Option<RowDedupState>,
    /// Whether strict probability domain checks are enabled.
    strict_probability_domain: bool,
}
426
427impl FixpointState {
428    /// Create a new fixpoint state for a rule.
429    pub fn new(
430        rule_name: String,
431        schema: SchemaRef,
432        key_column_indices: Vec<usize>,
433        max_derived_bytes: usize,
434        monotonic_agg: Option<MonotonicAggState>,
435        strict_probability_domain: bool,
436    ) -> Self {
437        let num_cols = schema.fields().len();
438        let row_dedup = RowDedupState::try_new(&schema);
439        let key_column_names: Vec<String> = key_column_indices
440            .iter()
441            .filter_map(|&i| schema.fields().get(i).map(|f| f.name().clone()))
442            .collect();
443        Self {
444            rule_name,
445            facts: Vec::new(),
446            delta: Vec::new(),
447            schema,
448            key_column_indices,
449            key_column_names,
450            all_column_indices: (0..num_cols).collect(),
451            facts_bytes: 0,
452            max_derived_bytes,
453            monotonic_agg,
454            row_dedup,
455            strict_probability_domain,
456        }
457    }
458
459    /// Reconcile the pre-computed schema with the actual physical plan output.
460    ///
461    /// `infer_expr_type` may guess wrong (e.g. `Property → Float64` for a
462    /// string column).  When the first real batch arrives with a different
463    /// schema, update ours so that `RowDedupState` / `RecordBatch::try_new`
464    /// use the correct types.
465    fn reconcile_schema(&mut self, actual_schema: &SchemaRef) {
466        if self.schema.fields() != actual_schema.fields() {
467            tracing::debug!(
468                rule = %self.rule_name,
469                "Reconciling fixpoint schema from physical plan output",
470            );
471            self.schema = Arc::clone(actual_schema);
472            self.row_dedup = RowDedupState::try_new(&self.schema);
473            // Recompute key_column_indices from stored KEY column names.
474            // Without this, FoldExec groups by wrong columns when the
475            // physical plan reorders columns vs the pre-inferred schema.
476            let new_indices: Vec<usize> = self
477                .key_column_names
478                .iter()
479                .filter_map(|name| actual_schema.index_of(name).ok())
480                .collect();
481            if new_indices.len() == self.key_column_names.len() {
482                self.key_column_indices = new_indices;
483            }
484            // else: not all KEY columns found in new schema — keep original indices
485            let num_cols = actual_schema.fields().len();
486            self.all_column_indices = (0..num_cols).collect();
487        }
488    }
489
    /// Merge candidate rows into facts, computing delta (truly new rows).
    ///
    /// Returns `true` if any new facts were added.
    ///
    /// # Errors
    ///
    /// Fails when appending the delta would exceed `max_derived_bytes`, or
    /// when strict probability-domain checks reject a monotonic aggregate
    /// input.
    pub async fn merge_delta(
        &mut self,
        candidates: Vec<RecordBatch>,
        task_ctx: Option<Arc<TaskContext>>,
    ) -> DFResult<bool> {
        // No candidate rows at all: clear last iteration's delta so
        // `is_converged` sees an empty delta, and report no change.
        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
            self.delta.clear();
            return Ok(false);
        }

        // Reconcile schema from the first non-empty candidate batch.
        // The physical plan's output types are authoritative over the
        // planner's inferred types.
        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
            self.reconcile_schema(&first.schema());
        }

        // Round floats for stable dedup
        let candidates = round_float_columns(&candidates);

        // Compute delta: rows in candidates not already in facts
        let delta = self.compute_delta(&candidates, task_ctx.as_ref()).await?;

        if delta.is_empty() || delta.iter().all(|b| b.num_rows() == 0) {
            self.delta.clear();
            // Update monotonic aggs even with empty delta (for stability check)
            if let Some(ref mut agg) = self.monotonic_agg {
                agg.snapshot();
            }
            return Ok(false);
        }

        // Check memory limit BEFORE appending, so a failing merge leaves
        // `facts` / `facts_bytes` unchanged.
        let delta_bytes: usize = delta.iter().map(batch_byte_size).sum();
        if self.facts_bytes + delta_bytes > self.max_derived_bytes {
            return Err(datafusion::error::DataFusionError::Execution(
                LocyRuntimeError::MemoryLimitExceeded {
                    rule: self.rule_name.clone(),
                    bytes: self.facts_bytes + delta_bytes,
                    limit: self.max_derived_bytes,
                }
                .to_string(),
            ));
        }

        // Update monotonic aggs. Order matters: snapshot() first so
        // `is_stable` later compares pre-update vs post-update values.
        if let Some(ref mut agg) = self.monotonic_agg {
            agg.snapshot();
            agg.update(
                &self.key_column_indices,
                &delta,
                self.strict_probability_domain,
            )?;
        }

        // Append delta to facts
        self.facts_bytes += delta_bytes;
        self.facts.extend(delta.iter().cloned());
        self.delta = delta;

        Ok(true)
    }
555
556    /// Dispatch to vectorized LeftAntiJoin, Arrow RowConverter dedup, or legacy ScalarKey dedup.
557    ///
558    /// Priority order:
559    /// 1. `arrow_left_anti_dedup` when `total_existing >= DEDUP_ANTI_JOIN_THRESHOLD` and task_ctx available.
560    /// 2. `RowDedupState` (persistent HashSet, O(M) per iteration) when schema is supported.
561    /// 3. `compute_delta_legacy` (rebuilds from facts, fallback for unsupported column types).
562    async fn compute_delta(
563        &mut self,
564        candidates: &[RecordBatch],
565        task_ctx: Option<&Arc<TaskContext>>,
566    ) -> DFResult<Vec<RecordBatch>> {
567        let total_existing: usize = self.facts.iter().map(|b| b.num_rows()).sum();
568        if total_existing >= DEDUP_ANTI_JOIN_THRESHOLD
569            && let Some(ctx) = task_ctx
570        {
571            return arrow_left_anti_dedup(candidates.to_vec(), &self.facts, &self.schema, ctx)
572                .await;
573        }
574        if let Some(ref mut rd) = self.row_dedup {
575            rd.compute_delta(candidates, &self.schema)
576        } else {
577            self.compute_delta_legacy(candidates)
578        }
579    }
580
581    /// Legacy dedup: rebuild a `HashSet<Vec<ScalarKey>>` from all facts each call.
582    ///
583    /// Used as fallback when `RowConverter` does not support the schema's column types.
584    fn compute_delta_legacy(&self, candidates: &[RecordBatch]) -> DFResult<Vec<RecordBatch>> {
585        // Build set of existing fact row keys (ALL columns)
586        let mut existing: HashSet<Vec<ScalarKey>> = HashSet::new();
587        for batch in &self.facts {
588            for row_idx in 0..batch.num_rows() {
589                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
590                existing.insert(key);
591            }
592        }
593
594        let mut delta_batches = Vec::new();
595        for batch in candidates {
596            if batch.num_rows() == 0 {
597                continue;
598            }
599            // Filter to only new rows
600            let mut keep = Vec::with_capacity(batch.num_rows());
601            for row_idx in 0..batch.num_rows() {
602                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
603                keep.push(!existing.contains(&key));
604            }
605
606            // Also dedup within the candidate batch itself
607            for (row_idx, kept) in keep.iter_mut().enumerate() {
608                if *kept {
609                    let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
610                    if !existing.insert(key) {
611                        *kept = false;
612                    }
613                }
614            }
615
616            let keep_mask = arrow_array::BooleanArray::from(keep);
617            let new_rows = batch
618                .columns()
619                .iter()
620                .map(|col| {
621                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
622                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
623                    })
624                })
625                .collect::<DFResult<Vec<_>>>()?;
626
627            if new_rows.first().is_some_and(|c| !c.is_empty()) {
628                let filtered =
629                    RecordBatch::try_new(Arc::clone(&self.schema), new_rows).map_err(|e| {
630                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
631                    })?;
632                delta_batches.push(filtered);
633            }
634        }
635
636        Ok(delta_batches)
637    }
638
639    /// Check if this rule has converged (no new facts and aggs stable).
640    pub fn is_converged(&self) -> bool {
641        let delta_empty = self.delta.is_empty() || self.delta.iter().all(|b| b.num_rows() == 0);
642        let agg_stable = self.monotonic_agg.as_ref().is_none_or(|a| a.is_stable());
643        delta_empty && agg_stable
644    }
645
    /// Get all accumulated facts.
    ///
    /// Includes rows accepted in every iteration so far, not just the latest
    /// delta.
    pub fn all_facts(&self) -> &[RecordBatch] {
        &self.facts
    }
650
    /// Get the delta from the latest iteration.
    ///
    /// Empty (or all batches empty) when the latest merge added nothing new.
    pub fn all_delta(&self) -> &[RecordBatch] {
        &self.delta
    }
655
    /// Consume self and return facts.
    ///
    /// Yields the accumulated fact batches by move, without cloning.
    pub fn into_facts(self) -> Vec<RecordBatch> {
        self.facts
    }
660
661    /// Merge candidates using BEST BY semantics.
662    ///
663    /// Combines existing facts with new candidates, keeping only the best row
664    /// per KEY group according to `sort_criteria`. Returns `true` if the
665    /// best-per-KEY fact set actually changed (a genuinely better value was
666    /// found or a new KEY appeared).
667    ///
668    /// This replaces `merge_delta` for rules with BEST BY, enabling convergence
669    /// on cyclic graphs where dominated ALONG values would otherwise produce an
670    /// unbounded stream of "new" full-row facts.
671    pub fn merge_best_by(
672        &mut self,
673        candidates: Vec<RecordBatch>,
674        sort_criteria: &[SortCriterion],
675    ) -> DFResult<bool> {
676        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
677            self.delta.clear();
678            return Ok(false);
679        }
680
681        // Reconcile schema from the first non-empty candidate batch.
682        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
683            self.reconcile_schema(&first.schema());
684        }
685
686        // Round floats for stable dedup.
687        let candidates = round_float_columns(&candidates);
688
689        // Snapshot existing best-per-KEY facts for change detection.
690        let old_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> =
691            self.build_key_criteria_map(sort_criteria);
692
693        // Concat existing facts + new candidates.
694        let mut all_batches = self.facts.clone();
695        all_batches.extend(candidates);
696        let all_batches: Vec<_> = all_batches
697            .into_iter()
698            .filter(|b| b.num_rows() > 0)
699            .collect();
700        if all_batches.is_empty() {
701            self.delta.clear();
702            return Ok(false);
703        }
704
705        let combined = arrow::compute::concat_batches(&self.schema, &all_batches)
706            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
707
708        if combined.num_rows() == 0 {
709            self.delta.clear();
710            return Ok(false);
711        }
712
713        // Sort by KEY ASC then criteria, so the best row per KEY group comes
714        // first.
715        let mut sort_columns = Vec::new();
716        for &ki in &self.key_column_indices {
717            if ki >= combined.num_columns() {
718                continue;
719            }
720            sort_columns.push(arrow::compute::SortColumn {
721                values: Arc::clone(combined.column(ki)),
722                options: Some(arrow::compute::SortOptions {
723                    descending: false,
724                    nulls_first: false,
725                }),
726            });
727        }
728        for criterion in sort_criteria {
729            if criterion.col_index >= combined.num_columns() {
730                continue;
731            }
732            sort_columns.push(arrow::compute::SortColumn {
733                values: Arc::clone(combined.column(criterion.col_index)),
734                options: Some(arrow::compute::SortOptions {
735                    descending: !criterion.ascending,
736                    nulls_first: criterion.nulls_first,
737                }),
738            });
739        }
740
741        let sorted_indices =
742            arrow::compute::lexsort_to_indices(&sort_columns, None).map_err(arrow_err)?;
743        let sorted_columns: Vec<_> = combined
744            .columns()
745            .iter()
746            .map(|col| arrow::compute::take(col.as_ref(), &sorted_indices, None))
747            .collect::<Result<Vec<_>, _>>()
748            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
749        let sorted = RecordBatch::try_new(Arc::clone(&self.schema), sorted_columns)
750            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
751
752        // Dedup: keep first (best) row per KEY group.
753        let mut keep_indices: Vec<u32> = Vec::new();
754        let mut prev_key: Option<Vec<ScalarKey>> = None;
755        for row_idx in 0..sorted.num_rows() {
756            let key = extract_scalar_key(&sorted, &self.key_column_indices, row_idx);
757            let is_new_group = match &prev_key {
758                None => true,
759                Some(prev) => *prev != key,
760            };
761            if is_new_group {
762                keep_indices.push(row_idx as u32);
763                prev_key = Some(key);
764            }
765        }
766
767        let keep_array = arrow_array::UInt32Array::from(keep_indices);
768        let output_columns: Vec<_> = sorted
769            .columns()
770            .iter()
771            .map(|col| arrow::compute::take(col.as_ref(), &keep_array, None))
772            .collect::<Result<Vec<_>, _>>()
773            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
774        let pruned = RecordBatch::try_new(Arc::clone(&self.schema), output_columns)
775            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
776
777        // Detect whether the best-per-KEY set actually changed.
778        let new_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> = {
779            let mut map = HashMap::new();
780            for row_idx in 0..pruned.num_rows() {
781                let key = extract_scalar_key(&pruned, &self.key_column_indices, row_idx);
782                let criteria: Vec<ScalarKey> = sort_criteria
783                    .iter()
784                    .flat_map(|c| extract_scalar_key(&pruned, &[c.col_index], row_idx))
785                    .collect();
786                map.insert(key, criteria);
787            }
788            map
789        };
790        let changed = old_best != new_best;
791
792        tracing::debug!(
793            rule = %self.rule_name,
794            old_keys = old_best.len(),
795            new_keys = new_best.len(),
796            changed = changed,
797            "BEST BY merge"
798        );
799
800        // Replace facts with the pruned set.
801        self.facts_bytes = batch_byte_size(&pruned);
802        self.facts = vec![pruned];
803        if changed {
804            // Delta is conceptually the new/improved facts, but since we
805            // replaced the entire set, just mark delta non-empty.
806            self.delta = self.facts.clone();
807        } else {
808            self.delta.clear();
809        }
810
811        // Rebuild row dedup from pruned facts for consistency.
812        self.row_dedup = RowDedupState::try_new(&self.schema);
813        if let Some(ref mut rd) = self.row_dedup {
814            rd.ingest_existing(&self.facts, &self.schema);
815        }
816
817        Ok(changed)
818    }
819
820    /// Build a map from KEY column values to sort criteria values.
821    fn build_key_criteria_map(
822        &self,
823        sort_criteria: &[SortCriterion],
824    ) -> HashMap<Vec<ScalarKey>, Vec<ScalarKey>> {
825        let mut map = HashMap::new();
826        for batch in &self.facts {
827            for row_idx in 0..batch.num_rows() {
828                let key = extract_scalar_key(batch, &self.key_column_indices, row_idx);
829                let criteria: Vec<ScalarKey> = sort_criteria
830                    .iter()
831                    .flat_map(|c| extract_scalar_key(batch, &[c.col_index], row_idx))
832                    .collect();
833                map.insert(key, criteria);
834            }
835        }
836        map
837    }
838}
839
840/// Estimate byte size of a RecordBatch.
841fn batch_byte_size(batch: &RecordBatch) -> usize {
842    batch
843        .columns()
844        .iter()
845        .map(|col| col.get_buffer_memory_size())
846        .sum()
847}
848
849// ---------------------------------------------------------------------------
850// Float rounding for stable dedup
851// ---------------------------------------------------------------------------
852
853/// Round all Float64 columns to 12 decimal places for stable dedup.
854fn round_float_columns(batches: &[RecordBatch]) -> Vec<RecordBatch> {
855    batches
856        .iter()
857        .map(|batch| {
858            let schema = batch.schema();
859            let has_float = schema
860                .fields()
861                .iter()
862                .any(|f| *f.data_type() == arrow_schema::DataType::Float64);
863            if !has_float {
864                return batch.clone();
865            }
866
867            let columns: Vec<arrow_array::ArrayRef> = batch
868                .columns()
869                .iter()
870                .enumerate()
871                .map(|(i, col)| {
872                    if *schema.field(i).data_type() == arrow_schema::DataType::Float64 {
873                        let arr = col
874                            .as_any()
875                            .downcast_ref::<arrow_array::Float64Array>()
876                            .unwrap();
877                        let rounded: arrow_array::Float64Array = arr
878                            .iter()
879                            .map(|v| v.map(|f| (f * 1e12).round() / 1e12))
880                            .collect();
881                        Arc::new(rounded) as arrow_array::ArrayRef
882                    } else {
883                        Arc::clone(col)
884                    }
885                })
886                .collect();
887
888            RecordBatch::try_new(schema, columns).unwrap_or_else(|_| batch.clone())
889        })
890        .collect()
891}
892
893// ---------------------------------------------------------------------------
894// LeftAntiJoin delta deduplication
895// ---------------------------------------------------------------------------
896
/// Row threshold above which the vectorized Arrow LeftAntiJoin dedup path is used.
///
/// Below this threshold the persistent `RowDedupState` HashSet is O(M) and
/// avoids rebuilding the existing-row set; above it DataFusion's vectorized
/// HashJoinExec is more cache-efficient.
///
/// NOTE(review): 300 is a heuristic crossover point — presumably tuned
/// empirically; confirm against benchmarks before changing.
const DEDUP_ANTI_JOIN_THRESHOLD: usize = 300;
903
904/// Deduplicate `candidates` against `existing` using DataFusion's HashJoinExec.
905///
906/// Returns rows in `candidates` that do not appear in `existing` (LeftAnti semantics).
907/// `null_equals_null = true` so NULLs are treated as equal for dedup purposes.
908async fn arrow_left_anti_dedup(
909    candidates: Vec<RecordBatch>,
910    existing: &[RecordBatch],
911    schema: &SchemaRef,
912    task_ctx: &Arc<TaskContext>,
913) -> DFResult<Vec<RecordBatch>> {
914    if existing.is_empty() || existing.iter().all(|b| b.num_rows() == 0) {
915        return Ok(candidates);
916    }
917
918    let left: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(candidates, Arc::clone(schema)));
919    let right: Arc<dyn ExecutionPlan> =
920        Arc::new(InMemoryExec::new(existing.to_vec(), Arc::clone(schema)));
921
922    let on: Vec<(
923        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
924        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
925    )> = schema
926        .fields()
927        .iter()
928        .enumerate()
929        .map(|(i, field)| {
930            let l: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
931                datafusion::physical_plan::expressions::Column::new(field.name(), i),
932            );
933            let r: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
934                datafusion::physical_plan::expressions::Column::new(field.name(), i),
935            );
936            (l, r)
937        })
938        .collect();
939
940    if on.is_empty() {
941        return Ok(vec![]);
942    }
943
944    let join = HashJoinExec::try_new(
945        left,
946        right,
947        on,
948        None,
949        &JoinType::LeftAnti,
950        None,
951        PartitionMode::CollectLeft,
952        datafusion::common::NullEquality::NullEqualsNull,
953    )?;
954
955    let join_arc: Arc<dyn ExecutionPlan> = Arc::new(join);
956    collect_all_partitions(&join_arc, task_ctx.clone()).await
957}
958
959// ---------------------------------------------------------------------------
960// Plan types for fixpoint rules
961// ---------------------------------------------------------------------------
962
/// IS-ref binding: a reference from a clause body to a derived relation.
///
/// The fixpoint loop uses these bindings to apply negation semantics
/// (anti-join or probabilistic complement) and to attribute provenance of
/// derived rows back to their source facts.
#[derive(Debug, Clone)]
pub struct IsRefBinding {
    /// Index into the DerivedScanRegistry.
    pub derived_scan_index: usize,
    /// Name of the rule being referenced.
    pub rule_name: String,
    /// Whether this is a self-reference (rule references itself).
    pub is_self_ref: bool,
    /// Whether this is a negated reference (NOT IS).
    pub negated: bool,
    /// For negated IS-refs: `(left_body_col, right_derived_col)` pairs for anti-join filtering.
    ///
    /// `left_body_col` is the VID column in the clause body (e.g., `"n._vid"`);
    /// `right_derived_col` is the corresponding KEY column in the negated rule's facts (e.g., `"n"`).
    /// Empty for non-negated IS-refs.
    pub anti_join_cols: Vec<(String, String)>,
    /// Whether the target rule has a PROB column.
    ///
    /// When true (and the referring rule also has a PROB column), negation is
    /// applied as a probabilistic complement instead of a hard anti-join.
    pub target_has_prob: bool,
    /// Name of the PROB column in the target rule, if any.
    pub target_prob_col: Option<String>,
    /// `(body_col, derived_col)` pairs for provenance tracking.
    ///
    /// Used by shared-proof detection to find which source facts a derived row
    /// consumed. Populated for all IS-refs (not just negated ones).
    pub provenance_join_cols: Vec<(String, String)>,
}
990
/// A single clause (body) within a fixpoint rule.
#[derive(Debug)]
pub struct FixpointClausePlan {
    /// The logical plan for the clause body.
    ///
    /// Re-executed on every fixpoint iteration via `execute_subplan`.
    pub body_logical: LogicalPlan,
    /// IS-ref bindings used by this clause.
    pub is_ref_bindings: Vec<IsRefBinding>,
    /// Priority value for this clause (if PRIORITY semantics apply).
    pub priority: Option<i64>,
    /// ALONG binding variable names propagated from the planner.
    ///
    /// Their per-row values are captured into provenance annotations by
    /// `record_provenance`.
    pub along_bindings: Vec<String>,
}
1003
/// Physical plan for a single rule in a fixpoint stratum.
#[derive(Debug)]
pub struct FixpointRulePlan {
    /// Rule name.
    pub name: String,
    /// Clause bodies (each evaluates to candidate rows).
    ///
    /// On each iteration every clause is re-evaluated and the union of its
    /// batches is merged into the rule's fact set.
    pub clauses: Vec<FixpointClausePlan>,
    /// Output schema for this rule's derived relation.
    pub yield_schema: SchemaRef,
    /// Indices of KEY columns within yield_schema.
    pub key_column_indices: Vec<usize>,
    /// Priority value (if PRIORITY semantics apply).
    pub priority: Option<i64>,
    /// Whether this rule has FOLD semantics.
    pub has_fold: bool,
    /// FOLD bindings for post-fixpoint aggregation.
    pub fold_bindings: Vec<FoldBinding>,
    /// Post-FOLD filter expressions (HAVING semantics).
    pub having: Vec<Expr>,
    /// Whether this rule has BEST BY semantics.
    pub has_best_by: bool,
    /// BEST BY sort criteria for post-fixpoint selection.
    ///
    /// Also used during iteration by `merge_best_by` to keep only the best
    /// row per KEY group, enabling convergence on cyclic graphs.
    pub best_by_criteria: Vec<SortCriterion>,
    /// Whether this rule has PRIORITY semantics.
    pub has_priority: bool,
    /// Whether BEST BY should apply a deterministic secondary sort for
    /// tie-breaking. When false, tied rows are selected non-deterministically
    /// (faster but not repeatable across runs).
    pub deterministic: bool,
    /// Name of the PROB column in this rule's yield schema, if any.
    pub prob_column_name: Option<String>,
}
1036
1037// ---------------------------------------------------------------------------
1038// run_fixpoint_loop — the core semi-naive iteration algorithm
1039// ---------------------------------------------------------------------------
1040
/// Run the semi-naive fixpoint iteration loop.
///
/// Evaluates all rules in a stratum repeatedly, feeding deltas back through
/// derived scan handles until convergence or limits are reached.
///
/// After convergence, each rule's facts go through shared-lineage detection,
/// optional exact weighted model counting, and the post-fixpoint chain
/// (FOLD/HAVING/BEST BY/PRIORITY) before being collected into the output.
///
/// # Errors
///
/// Returns a `NonConvergence` execution error when `timeout` elapses or all
/// `max_iterations` are exhausted without the stratum reaching a fixpoint.
#[expect(clippy::too_many_arguments, reason = "Fixpoint loop needs all context")]
async fn run_fixpoint_loop(
    rules: Vec<FixpointRulePlan>,
    max_iterations: usize,
    timeout: Duration,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    params: HashMap<String, Value>,
    registry: Arc<DerivedScanRegistry>,
    output_schema: SchemaRef,
    max_derived_bytes: usize,
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    top_k_proofs: usize,
) -> DFResult<Vec<RecordBatch>> {
    let start = Instant::now();
    let task_ctx = session_ctx.read().task_ctx();

    // Initialize per-rule state. Rules with FOLD bindings additionally get a
    // MonotonicAggState so their aggregates can be maintained during iteration.
    let mut states: Vec<FixpointState> = rules
        .iter()
        .map(|rule| {
            let monotonic_agg = if !rule.fold_bindings.is_empty() {
                let bindings: Vec<MonotonicFoldBinding> = rule
                    .fold_bindings
                    .iter()
                    .map(|fb| MonotonicFoldBinding {
                        fold_name: fb.output_name.clone(),
                        kind: fb.kind.clone(),
                        input_col_index: fb.input_col_index,
                        input_col_name: fb.input_col_name.clone(),
                    })
                    .collect();
                Some(MonotonicAggState::new(bindings))
            } else {
                None
            };
            FixpointState::new(
                rule.name.clone(),
                Arc::clone(&rule.yield_schema),
                rule.key_column_indices.clone(),
                max_derived_bytes,
                monotonic_agg,
                strict_probability_domain,
            )
        })
        .collect();

    // Main iteration loop: each pass re-evaluates every rule's clauses against
    // the current derived facts, then merges the candidates back in.
    let mut converged = false;
    let mut total_iters = 0usize;
    for iteration in 0..max_iterations {
        total_iters = iteration + 1;
        tracing::debug!("fixpoint iteration {}", iteration);
        let mut any_changed = false;

        for rule_idx in 0..rules.len() {
            let rule = &rules[rule_idx];

            // Update derived scan handles for this rule's clauses so the
            // clause bodies see the latest facts/deltas of referenced rules.
            update_derived_scan_handles(&registry, &states, rule_idx, &rules);

            // Evaluate clause bodies, tracking per-clause candidates for provenance.
            let mut all_candidates = Vec::new();
            let mut clause_candidates: Vec<Vec<RecordBatch>> = Vec::new();
            for clause in &rule.clauses {
                let mut batches = execute_subplan(
                    &clause.body_logical,
                    &params,
                    &HashMap::new(),
                    &graph_ctx,
                    &session_ctx,
                    &storage,
                    &schema_info,
                )
                .await?;
                // Apply negated IS-ref semantics: probabilistic complement or anti-join.
                for binding in &clause.is_ref_bindings {
                    if binding.negated
                        && !binding.anti_join_cols.is_empty()
                        && let Some(entry) = registry.get(binding.derived_scan_index)
                    {
                        let neg_facts = entry.data.read().clone();
                        if !neg_facts.is_empty() {
                            if binding.target_has_prob && rule.prob_column_name.is_some() {
                                // Probabilistic complement: add 1-p column instead of filtering.
                                let complement_col =
                                    format!("__prob_complement_{}", binding.rule_name);
                                if let Some(prob_col) = &binding.target_prob_col {
                                    batches = apply_prob_complement_composite(
                                        batches,
                                        &neg_facts,
                                        &binding.anti_join_cols,
                                        prob_col,
                                        &complement_col,
                                    )?;
                                } else {
                                    // target_has_prob but no prob_col: fall back to anti-join.
                                    batches = apply_anti_join_composite(
                                        batches,
                                        &neg_facts,
                                        &binding.anti_join_cols,
                                    )?;
                                }
                            } else {
                                // Boolean exclusion: anti-join (existing behavior)
                                batches = apply_anti_join_composite(
                                    batches,
                                    &neg_facts,
                                    &binding.anti_join_cols,
                                )?;
                            }
                        }
                    }
                }
                // Multiply complement columns into the PROB column (if any) and clean up.
                // Complement columns are recognized by the "__prob_complement_" prefix
                // added above; the first batch's schema is representative.
                let complement_cols: Vec<String> = if !batches.is_empty() {
                    batches[0]
                        .schema()
                        .fields()
                        .iter()
                        .filter(|f| f.name().starts_with("__prob_complement_"))
                        .map(|f| f.name().clone())
                        .collect()
                } else {
                    vec![]
                };
                if !complement_cols.is_empty() {
                    batches = multiply_prob_factors(
                        batches,
                        rule.prob_column_name.as_deref(),
                        &complement_cols,
                    )?;
                }

                clause_candidates.push(batches.clone());
                all_candidates.extend(batches);
            }

            // Merge candidates into facts.
            // For BEST BY rules, use a specialized merge that keeps only the
            // best row per KEY group, enabling convergence on cyclic graphs.
            let changed = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
                states[rule_idx].merge_best_by(all_candidates, &rule.best_by_criteria)?
            } else {
                states[rule_idx]
                    .merge_delta(all_candidates, Some(Arc::clone(&task_ctx)))
                    .await?
            };
            if changed {
                any_changed = true;
                // Record provenance for newly derived facts when tracker is present.
                if let Some(ref tracker) = derivation_tracker {
                    record_provenance(
                        tracker,
                        rule,
                        &states[rule_idx],
                        &clause_candidates,
                        iteration,
                        &registry,
                        top_k_proofs,
                    );
                }
            }
        }

        // Check convergence: no rule produced new facts this pass, and every
        // per-rule state also reports itself converged.
        if !any_changed && states.iter().all(|s| s.is_converged()) {
            tracing::debug!("fixpoint converged after {} iterations", iteration + 1);
            converged = true;
            break;
        }

        // Check timeout. This runs after the convergence check, so a run that
        // converges exactly at the deadline still succeeds.
        if start.elapsed() > timeout {
            return Err(datafusion::error::DataFusionError::Execution(
                LocyRuntimeError::NonConvergence {
                    iterations: iteration + 1,
                }
                .to_string(),
            ));
        }
    }

    // Write per-rule iteration counts to the shared slot.
    // Best-effort telemetry: a poisoned lock is silently skipped.
    if let Ok(mut counts) = iteration_counts.write() {
        for rule in &rules {
            counts.insert(rule.name.clone(), total_iters);
        }
    }

    // If we exhausted all iterations without converging, return a non-convergence error.
    if !converged {
        return Err(datafusion::error::DataFusionError::Execution(
            LocyRuntimeError::NonConvergence {
                iterations: max_iterations,
            }
            .to_string(),
        ));
    }

    // Post-fixpoint processing per rule and collect output.
    let task_ctx = session_ctx.read().task_ctx();
    let mut all_output = Vec::new();

    for (rule_idx, state) in states.into_iter().enumerate() {
        let rule = &rules[rule_idx];
        let mut facts = state.into_facts();
        if facts.is_empty() {
            continue;
        }

        // Detect shared proofs before FOLD collapses groups.
        let shared_info = if let Some(ref tracker) = derivation_tracker {
            detect_shared_lineage(rule, &facts, tracker, &warnings_slot)
        } else {
            None
        };

        // Apply BDD for shared groups if exact_probability is enabled.
        if exact_probability
            && let Some(ref info) = shared_info
            && let Some(ref tracker) = derivation_tracker
        {
            facts = apply_exact_wmc(
                facts,
                rule,
                info,
                tracker,
                max_bdd_variables,
                &warnings_slot,
                &approximate_slot,
            )?;
        }

        let processed = apply_post_fixpoint_chain(
            facts,
            rule,
            &task_ctx,
            strict_probability_domain,
            probability_epsilon,
        )
        .await?;
        all_output.extend(processed);
    }

    // If no output, return empty batch with output schema.
    if all_output.is_empty() {
        all_output.push(RecordBatch::new_empty(output_schema));
    }

    Ok(all_output)
}
1306
1307// ---------------------------------------------------------------------------
1308// Provenance recording helpers
1309// ---------------------------------------------------------------------------
1310
1311/// Record provenance for all newly derived facts (rows in the current delta).
1312///
1313/// Called after `merge_delta` returns `true`. Attributes each new fact to the
1314/// clause most likely to have produced it, using first-derivation-wins semantics.
1315fn record_provenance(
1316    tracker: &Arc<ProvenanceStore>,
1317    rule: &FixpointRulePlan,
1318    state: &FixpointState,
1319    clause_candidates: &[Vec<RecordBatch>],
1320    iteration: usize,
1321    registry: &Arc<DerivedScanRegistry>,
1322    top_k_proofs: usize,
1323) {
1324    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
1325
1326    // Pre-compute base fact probabilities for top-k mode.
1327    let base_probs = if top_k_proofs > 0 {
1328        tracker.base_fact_probs()
1329    } else {
1330        HashMap::new()
1331    };
1332
1333    for delta_batch in state.all_delta() {
1334        for row_idx in 0..delta_batch.num_rows() {
1335            let row_hash = format!(
1336                "{:?}",
1337                extract_scalar_key(delta_batch, &all_indices, row_idx)
1338            )
1339            .into_bytes();
1340            let fact_row = batch_row_to_value_map(delta_batch, row_idx);
1341            let clause_index =
1342                find_clause_for_row(delta_batch, row_idx, &all_indices, clause_candidates);
1343
1344            let support = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);
1345
1346            let proof_probability = if top_k_proofs > 0 {
1347                compute_proof_probability(&support, &base_probs)
1348            } else {
1349                None
1350            };
1351
1352            let entry = ProvenanceAnnotation {
1353                rule_name: rule.name.clone(),
1354                clause_index,
1355                support,
1356                along_values: {
1357                    let along_names: Vec<String> = rule
1358                        .clauses
1359                        .get(clause_index)
1360                        .map(|c| c.along_bindings.clone())
1361                        .unwrap_or_default();
1362                    along_names
1363                        .iter()
1364                        .filter_map(|name| fact_row.get(name).map(|v| (name.clone(), v.clone())))
1365                        .collect()
1366                },
1367                iteration,
1368                fact_row,
1369                proof_probability,
1370            };
1371            if top_k_proofs > 0 {
1372                tracker.record_top_k(row_hash, entry, top_k_proofs);
1373            } else {
1374                tracker.record(row_hash, entry);
1375            }
1376        }
1377    }
1378}
1379
1380/// Collect IS-ref input facts for a derived row using provenance join columns.
1381///
1382/// For each non-negated IS-ref binding in the clause, extracts body-side key
1383/// values from the delta row and finds matching source rows in the registry.
1384/// Returns a `ProofTerm` for each match (with the source fact hash).
1385fn collect_is_ref_inputs(
1386    rule: &FixpointRulePlan,
1387    clause_index: usize,
1388    delta_batch: &RecordBatch,
1389    row_idx: usize,
1390    registry: &Arc<DerivedScanRegistry>,
1391) -> Vec<ProofTerm> {
1392    let clause = match rule.clauses.get(clause_index) {
1393        Some(c) => c,
1394        None => return vec![],
1395    };
1396
1397    let mut inputs = Vec::new();
1398    let delta_schema = delta_batch.schema();
1399
1400    for binding in &clause.is_ref_bindings {
1401        if binding.negated {
1402            continue;
1403        }
1404        if binding.provenance_join_cols.is_empty() {
1405            continue;
1406        }
1407
1408        // Extract body-side values from the delta row for each provenance join col.
1409        let body_values: Vec<(String, ScalarKey)> = binding
1410            .provenance_join_cols
1411            .iter()
1412            .filter_map(|(body_col, _derived_col)| {
1413                let col_idx = delta_schema
1414                    .fields()
1415                    .iter()
1416                    .position(|f| f.name() == body_col)?;
1417                let key = extract_scalar_key(delta_batch, &[col_idx], row_idx);
1418                Some((body_col.clone(), key.into_iter().next()?))
1419            })
1420            .collect();
1421
1422        if body_values.len() != binding.provenance_join_cols.len() {
1423            continue;
1424        }
1425
1426        // Read current data from the registry entry for this IS-ref's rule.
1427        let entry = match registry.get(binding.derived_scan_index) {
1428            Some(e) => e,
1429            None => continue,
1430        };
1431        let source_batches = entry.data.read();
1432        let source_schema = &entry.schema;
1433
1434        // Find matching source rows and hash them.
1435        for src_batch in source_batches.iter() {
1436            let all_src_indices: Vec<usize> = (0..src_batch.num_columns()).collect();
1437            for src_row in 0..src_batch.num_rows() {
1438                let matches = binding.provenance_join_cols.iter().enumerate().all(
1439                    |(i, (_body_col, derived_col))| {
1440                        let src_col_idx = source_schema
1441                            .fields()
1442                            .iter()
1443                            .position(|f| f.name() == derived_col);
1444                        match src_col_idx {
1445                            Some(idx) => {
1446                                let src_key = extract_scalar_key(src_batch, &[idx], src_row);
1447                                src_key.first() == Some(&body_values[i].1)
1448                            }
1449                            None => false,
1450                        }
1451                    },
1452                );
1453                if matches {
1454                    let fact_hash = format!(
1455                        "{:?}",
1456                        extract_scalar_key(src_batch, &all_src_indices, src_row)
1457                    )
1458                    .into_bytes();
1459                    inputs.push(ProofTerm {
1460                        source_rule: binding.rule_name.clone(),
1461                        base_fact_id: fact_hash,
1462                    });
1463                }
1464            }
1465        }
1466    }
1467
1468    inputs
1469}
1470
1471// ---------------------------------------------------------------------------
1472// Shared-lineage detection
1473// ---------------------------------------------------------------------------
1474
// NOTE(review): the paragraph below describes `detect_shared_lineage`
// (defined later in this file); it was previously written as a `///` doc
// comment here, which attached it to `SharedGroupRow` in rustdoc. Kept as a
// plain comment so the struct's own doc is accurate.
//
// Detect KEY groups in a rule's pre-fold facts where recursive derivation
// may violate the independence assumption of MNOR/MPROD.
//
// Uses a two-tier strategy:
// 1. **Precise**: If the `ProvenanceStore` has populated `support` for facts
//    in the group, we recursively compute lineage (Cui & Widom 2000) and
//    check for pairwise overlap. A shared base fact proves a dependency.
// 2. **Structural fallback**: When lineage tracking is unavailable (e.g., the
//    IS-ref subject variables were projected away), we check whether any fact
//    in a multi-row group was derived by a clause that has IS-ref bindings.
//    Recursive derivation through shared relations is a strong signal that
//    proof paths may share intermediate nodes.

/// Per-row data collected during shared-lineage detection.
#[expect(
    dead_code,
    reason = "Fields accessed via SharedLineageInfo in detect_shared_lineage"
)]
pub(crate) struct SharedGroupRow {
    /// Full-row hash identifying this fact (see `fact_hash_key`).
    pub fact_hash: Vec<u8>,
    /// Hashes of base facts in this row's derivation lineage.
    pub lineage: HashSet<Vec<u8>>,
}
1497
/// Information about groups with shared proofs, returned by `detect_shared_lineage`.
pub(crate) struct SharedLineageInfo {
    /// KEY group → rows with their base fact sets.
    ///
    /// Keys are the rule's KEY column values (in `key_column_indices` order).
    pub shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>>,
}
1503
1504/// Build a byte key that uniquely identifies a row across all columns.
1505fn fact_hash_key(batch: &RecordBatch, all_indices: &[usize], row_idx: usize) -> Vec<u8> {
1506    format!("{:?}", extract_scalar_key(batch, all_indices, row_idx)).into_bytes()
1507}
1508
/// Detect KEY groups in a rule's pre-fold facts whose MNOR/MPROD aggregation
/// may violate the independence assumption.
///
/// Two-tier detection (see the per-group loop below):
/// 1. Precise — when the `ProvenanceStore` has `support` recorded for facts in
///    the group, lineage is computed transitively via `compute_lineage` and
///    checked for pairwise overlap; a shared base fact proves a dependency.
/// 2. Structural fallback — without tracked inputs, a multi-row group is
///    flagged when any of its facts was derived by a clause with non-negated
///    IS-ref bindings.
///
/// Also emits a `CrossGroupCorrelationNotExact` warning when a single IS-ref
/// base fact feeds more than one KEY group.
///
/// Emits at most one `SharedProbabilisticDependency` warning per rule.
/// Returns `Some(SharedLineageInfo)` if any group has shared proofs.
fn detect_shared_lineage(
    rule: &FixpointRulePlan,
    pre_fold_facts: &[RecordBatch],
    tracker: &Arc<ProvenanceStore>,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
) -> Option<SharedLineageInfo> {
    use crate::query::df_graph::locy_fold::FoldAggKind;
    use uni_locy::{RuntimeWarning, RuntimeWarningCode};

    // Only check rules with MNOR/MPROD fold bindings.
    let has_prob_fold = rule
        .fold_bindings
        .iter()
        .any(|fb| matches!(fb.kind, FoldAggKind::Nor | FoldAggKind::Prod));
    if !has_prob_fold {
        return None;
    }

    // Group facts by KEY columns.
    let key_indices = &rule.key_column_indices;
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    // KEY → all fact hashes in that group (across every input batch).
    let mut groups: HashMap<Vec<ScalarKey>, Vec<Vec<u8>>> = HashMap::new();
    for batch in pre_fold_facts {
        for row_idx in 0..batch.num_rows() {
            let key = extract_scalar_key(batch, key_indices, row_idx);
            let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
            groups.entry(key).or_default().push(fact_hash);
        }
    }

    let mut shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>> = HashMap::new();
    let mut any_shared = false;

    // Check each group with ≥2 rows (single-row groups cannot share proofs).
    for (key, fact_hashes) in &groups {
        if fact_hashes.len() < 2 {
            continue;
        }

        // Tier 1: precise base-fact overlap detection via tracker inputs.
        // A fresh `visited` set per row scopes cycle detection to one derivation.
        let mut has_inputs = false;
        let mut per_row_bases: Vec<HashSet<Vec<u8>>> = Vec::new();
        for fh in fact_hashes {
            let bases = compute_lineage(fh, tracker, &mut HashSet::new());
            if let Some(entry) = tracker.lookup(fh)
                && !entry.support.is_empty()
            {
                has_inputs = true;
            }
            per_row_bases.push(bases);
        }

        let shared_found = if has_inputs {
            // At least some facts have tracked inputs — do precise comparison.
            // Any non-disjoint pair of base-fact sets proves shared lineage.
            let mut found = false;
            'outer: for i in 0..per_row_bases.len() {
                for j in (i + 1)..per_row_bases.len() {
                    if !per_row_bases[i].is_disjoint(&per_row_bases[j]) {
                        found = true;
                        break 'outer;
                    }
                }
            }
            found
        } else {
            // Tier 2: structural fallback — check if any fact in the group was
            // derived by a clause with IS-ref bindings (recursive derivation).
            fact_hashes.iter().any(|fh| {
                tracker.lookup(fh).is_some_and(|entry| {
                    rule.clauses
                        .get(entry.clause_index)
                        .is_some_and(|clause| clause.is_ref_bindings.iter().any(|b| !b.negated))
                })
            })
        };

        if shared_found {
            any_shared = true;
            // Collect the group rows with their base facts for BDD use.
            let rows: Vec<SharedGroupRow> = fact_hashes
                .iter()
                .zip(per_row_bases.into_iter())
                .map(|(fh, bases)| SharedGroupRow {
                    fact_hash: fh.clone(),
                    lineage: bases,
                })
                .collect();
            shared_groups.insert(key.clone(), rows);
        }
    }

    // Phase 5: Cross-group correlation warning.
    // Check if any IS-ref input fact appears in multiple KEY groups.
    // This is independent of within-group sharing: even rules whose KEY groups
    // each have only one post-fold row can exhibit cross-group correlation when
    // different groups consume the same IS-ref base fact.
    {
        // Invert the mapping: IS-ref base fact → set of KEY groups consuming it.
        let mut input_to_groups: HashMap<Vec<u8>, HashSet<Vec<ScalarKey>>> = HashMap::new();
        for (key, fact_hashes) in &groups {
            for fh in fact_hashes {
                if let Some(entry) = tracker.lookup(fh) {
                    for input in &entry.support {
                        input_to_groups
                            .entry(input.base_fact_id.clone())
                            .or_default()
                            .insert(key.clone());
                    }
                }
            }
        }
        let has_cross_group = input_to_groups.values().any(|g| g.len() > 1);
        // Deduplicated per (code, rule): warn at most once for this rule.
        if has_cross_group && let Ok(mut warnings) = warnings_slot.write() {
            let already_warned = warnings.iter().any(|w| {
                w.code == RuntimeWarningCode::CrossGroupCorrelationNotExact
                    && w.rule_name == rule.name
            });
            if !already_warned {
                warnings.push(RuntimeWarning {
                    code: RuntimeWarningCode::CrossGroupCorrelationNotExact,
                    message: format!(
                        "Rule '{}': IS-ref base facts are shared across different KEY \
                         groups. BDD corrects per-group probabilities but cannot account \
                         for cross-group correlations.",
                        rule.name
                    ),
                    rule_name: rule.name.clone(),
                    variable_count: None,
                    key_group: None,
                });
            }
        }
    }

    // Emit the per-rule SharedProbabilisticDependency warning (deduplicated by
    // (code, rule), so repeated invocations never stack duplicates).
    if any_shared {
        if let Ok(mut warnings) = warnings_slot.write() {
            let already_warned = warnings.iter().any(|w| {
                w.code == RuntimeWarningCode::SharedProbabilisticDependency
                    && w.rule_name == rule.name
            });
            if !already_warned {
                warnings.push(RuntimeWarning {
                    code: RuntimeWarningCode::SharedProbabilisticDependency,
                    message: format!(
                        "Rule '{}' aggregates with MNOR/MPROD but some proof paths \
                         share intermediate facts, violating the independence assumption. \
                         Results may overestimate probability.",
                        rule.name
                    ),
                    rule_name: rule.name.clone(),
                    variable_count: None,
                    key_group: None,
                });
            }
        }
        Some(SharedLineageInfo { shared_groups })
    } else {
        None
    }
}
1671
1672/// Record provenance and detect shared proofs for non-recursive strata.
1673///
1674/// Non-recursive rules are evaluated in a single pass (no fixpoint loop), so
1675/// the regular `record_provenance` + `detect_shared_lineage` path is never hit.
1676/// This function bridges that gap by recording a `ProvenanceAnnotation` for every
1677/// fact produced by each clause and then running the same two-tier detection
1678/// logic used by the recursive path.
1679pub(crate) fn record_and_detect_lineage_nonrecursive(
1680    rule: &FixpointRulePlan,
1681    tagged_clause_facts: &[(usize, Vec<RecordBatch>)],
1682    tracker: &Arc<ProvenanceStore>,
1683    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
1684    registry: &Arc<DerivedScanRegistry>,
1685    top_k_proofs: usize,
1686) -> Option<SharedLineageInfo> {
1687    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
1688
1689    // Pre-compute base fact probabilities for top-k mode.
1690    let base_probs = if top_k_proofs > 0 {
1691        tracker.base_fact_probs()
1692    } else {
1693        HashMap::new()
1694    };
1695
1696    // Record provenance for each clause's facts.
1697    for (clause_index, batches) in tagged_clause_facts {
1698        for batch in batches {
1699            for row_idx in 0..batch.num_rows() {
1700                let row_hash = fact_hash_key(batch, &all_indices, row_idx);
1701                let fact_row = batch_row_to_value_map(batch, row_idx);
1702
1703                let support = collect_is_ref_inputs(rule, *clause_index, batch, row_idx, registry);
1704
1705                let proof_probability = if top_k_proofs > 0 {
1706                    compute_proof_probability(&support, &base_probs)
1707                } else {
1708                    None
1709                };
1710
1711                let entry = ProvenanceAnnotation {
1712                    rule_name: rule.name.clone(),
1713                    clause_index: *clause_index,
1714                    support,
1715                    along_values: {
1716                        let along_names: Vec<String> = rule
1717                            .clauses
1718                            .get(*clause_index)
1719                            .map(|c| c.along_bindings.clone())
1720                            .unwrap_or_default();
1721                        along_names
1722                            .iter()
1723                            .filter_map(|name| {
1724                                fact_row.get(name).map(|v| (name.clone(), v.clone()))
1725                            })
1726                            .collect()
1727                    },
1728                    iteration: 0,
1729                    fact_row,
1730                    proof_probability,
1731                };
1732                if top_k_proofs > 0 {
1733                    tracker.record_top_k(row_hash, entry, top_k_proofs);
1734                } else {
1735                    tracker.record(row_hash, entry);
1736                }
1737            }
1738        }
1739    }
1740
1741    // Flatten all clause facts and run detection.
1742    let all_facts: Vec<RecordBatch> = tagged_clause_facts
1743        .iter()
1744        .flat_map(|(_, batches)| batches.iter().cloned())
1745        .collect();
1746    detect_shared_lineage(rule, &all_facts, tracker, warnings_slot)
1747}
1748
/// Apply exact weighted model counting (WMC) for shared-lineage groups.
///
/// Replaces multiple rows in groups with shared lineage with a single
/// representative row whose PROB column carries the BDD-computed exact
/// probability (Sang et al. 2005). For groups that exceed
/// `max_bdd_variables`, rows are left unchanged and a `BddLimitExceeded`
/// warning is emitted.
///
/// Rows whose KEY is not listed in `shared_info.shared_groups` pass through
/// untouched; batches that become empty after filtering are dropped.
pub(crate) fn apply_exact_wmc(
    pre_fold_facts: Vec<RecordBatch>,
    rule: &FixpointRulePlan,
    shared_info: &SharedLineageInfo,
    tracker: &Arc<ProvenanceStore>,
    max_bdd_variables: usize,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
    approximate_slot: &Arc<StdRwLock<HashMap<String, Vec<String>>>>,
) -> DFResult<Vec<RecordBatch>> {
    use crate::query::df_graph::locy_bdd::{SemiringOp, weighted_model_count};
    use crate::query::df_graph::locy_fold::FoldAggKind;
    use uni_locy::{RuntimeWarning, RuntimeWarningCode};

    // Find the MNOR/MPROD fold binding to know which column to overwrite.
    // Without one there is nothing to correct — return the input unchanged.
    let prob_fold = rule
        .fold_bindings
        .iter()
        .find(|fb| matches!(fb.kind, FoldAggKind::Nor | FoldAggKind::Prod));
    let prob_fold = match prob_fold {
        Some(f) => f,
        None => return Ok(pre_fold_facts),
    };
    // MNOR maps to disjunction, MPROD to conjunction in the BDD semiring.
    let semiring_op = if matches!(prob_fold.kind, FoldAggKind::Nor) {
        SemiringOp::Disjunction
    } else {
        SemiringOp::Conjunction
    };
    let prob_col_idx = prob_fold.input_col_index;
    let prob_col_name = rule.yield_schema.field(prob_col_idx).name().clone();

    let key_indices = &rule.key_column_indices;
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    // Build a set of shared group keys for quick lookup.
    let shared_keys: HashSet<Vec<ScalarKey>> = shared_info.shared_groups.keys().cloned().collect();

    // Phase 1: Collect all rows for each shared KEY group across all batches.
    // Store (batch_index, row_index) pairs for each group.
    struct GroupAccum {
        // One base-fact set per row in the group (its lineage).
        base_facts: Vec<HashSet<Vec<u8>>>,
        // Base fact hash → probability read from the fact's PROB column.
        base_probs: HashMap<Vec<u8>, f64>,
        /// First occurrence: (batch_index, row_index) — used as representative.
        representative: (usize, usize),
        // Every (batch_index, row_index) belonging to this group.
        row_locations: Vec<(usize, usize)>,
    }

    let mut group_accums: HashMap<Vec<ScalarKey>, GroupAccum> = HashMap::new();
    let mut non_shared_rows: Vec<(usize, usize)> = Vec::new(); // (batch_idx, row_idx)

    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
        for row_idx in 0..batch.num_rows() {
            let key = extract_scalar_key(batch, key_indices, row_idx);
            if shared_keys.contains(&key) {
                let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
                let bases = compute_lineage(&fact_hash, tracker, &mut HashSet::new());

                let accum = group_accums.entry(key).or_insert_with(|| GroupAccum {
                    base_facts: Vec::new(),
                    base_probs: HashMap::new(),
                    representative: (batch_idx, row_idx),
                    row_locations: Vec::new(),
                });

                // Look up probabilities for base facts (only once per base fact;
                // facts without a usable PROB value are simply absent from the map).
                for bf in &bases {
                    if !accum.base_probs.contains_key(bf)
                        && let Some(entry) = tracker.lookup(bf)
                        && let Some(val) = entry.fact_row.get(&prob_col_name)
                        && let Some(p) = value_to_f64(val)
                    {
                        accum.base_probs.insert(bf.clone(), p);
                    }
                }

                accum.base_facts.push(bases);
                accum.row_locations.push((batch_idx, row_idx));
            } else {
                non_shared_rows.push((batch_idx, row_idx));
            }
        }
    }

    // Phase 2: Compute BDD for each shared group (across all batches).
    // Track which (batch_idx, row_idx) pairs to keep vs drop.
    let mut keep_rows: HashSet<(usize, usize)> = HashSet::new();
    // Map of (batch_idx, row_idx) → overridden PROB value (for BDD-succeeded groups).
    let mut overrides: HashMap<(usize, usize), f64> = HashMap::new();

    // All non-shared rows are kept.
    for &loc in &non_shared_rows {
        keep_rows.insert(loc);
    }

    for (key, accum) in &group_accums {
        let bdd_result = weighted_model_count(
            &accum.base_facts,
            &accum.base_probs,
            semiring_op,
            max_bdd_variables,
        );

        if bdd_result.approximated {
            // Emit BddLimitExceeded warning (one per key group).
            if let Ok(mut warnings) = warnings_slot.write() {
                let key_desc = format!("{key:?}");
                let already_warned = warnings.iter().any(|w| {
                    w.code == RuntimeWarningCode::BddLimitExceeded
                        && w.rule_name == rule.name
                        && w.key_group.as_deref() == Some(&key_desc)
                });
                if !already_warned {
                    warnings.push(RuntimeWarning {
                        code: RuntimeWarningCode::BddLimitExceeded,
                        message: format!(
                            "Rule '{}': BDD variable limit exceeded ({} > {}). \
                             Falling back to independence-mode result.",
                            rule.name, bdd_result.variable_count, max_bdd_variables
                        ),
                        rule_name: rule.name.clone(),
                        variable_count: Some(bdd_result.variable_count),
                        key_group: Some(key_desc),
                    });
                }
            }
            // Record the group as approximate so callers can surface it.
            if let Ok(mut approx) = approximate_slot.write() {
                let key_desc = format!("{key:?}");
                approx.entry(rule.name.clone()).or_default().push(key_desc);
            }
            // Keep all rows unchanged (independence-mode fallback).
            for &loc in &accum.row_locations {
                keep_rows.insert(loc);
            }
        } else {
            // BDD succeeded: keep one representative row with overridden PROB;
            // the group's other rows are dropped (not in keep_rows).
            keep_rows.insert(accum.representative);
            overrides.insert(accum.representative, bdd_result.probability);
        }
    }

    // Phase 3: Build output batches by filtering kept rows per batch.
    let mut result_batches = Vec::new();
    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
        let kept_indices: Vec<usize> = (0..batch.num_rows())
            .filter(|&row_idx| keep_rows.contains(&(batch_idx, row_idx)))
            .collect();

        // Fully filtered batches are dropped from the output entirely.
        if kept_indices.is_empty() {
            continue;
        }

        let indices = arrow::array::UInt32Array::from(
            kept_indices.iter().map(|&i| i as u32).collect::<Vec<_>>(),
        );
        let mut columns: Vec<arrow::array::ArrayRef> = batch
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col, &indices, None))
            .collect::<Result<Vec<_>, _>>()
            .map_err(arrow_err)?;

        // Check if any kept rows have PROB overrides.
        let override_map: Vec<Option<f64>> = kept_indices
            .iter()
            .map(|&row_idx| overrides.get(&(batch_idx, row_idx)).copied())
            .collect();

        if override_map.iter().any(|o| o.is_some()) && prob_col_idx < columns.len() {
            // Rebuild the PROB column with overrides.
            // NOTE(review): assumes the PROB column is Float64 — if the downcast
            // fails, non-overridden rows silently fall back to 0.0. Confirm
            // upstream always materializes PROB as Float64.
            let existing_prob = columns[prob_col_idx]
                .as_any()
                .downcast_ref::<arrow::array::Float64Array>();
            let new_values: Vec<f64> = override_map
                .iter()
                .enumerate()
                .map(|(i, ov)| match ov {
                    Some(p) => *p,
                    None => existing_prob.map(|arr| arr.value(i)).unwrap_or(0.0),
                })
                .collect();
            columns[prob_col_idx] = Arc::new(arrow::array::Float64Array::from(new_values));
        }

        let result_batch = RecordBatch::try_new(batch.schema(), columns).map_err(arrow_err)?;
        result_batches.push(result_batch);
    }

    Ok(result_batches)
}
1944
1945/// Extract an f64 from a `Value`, supporting Float and Int.
1946fn value_to_f64(val: &uni_common::Value) -> Option<f64> {
1947    match val {
1948        uni_common::Value::Float(f) => Some(*f),
1949        uni_common::Value::Int(i) => Some(*i as f64),
1950        _ => None,
1951    }
1952}
1953
1954/// Compute the lineage of a derived fact (Cui & Widom 2000).
1955///
1956/// Recursively traverses the provenance store to collect the set of base-level
1957/// fact hashes that contribute to this derivation. A base fact is one with no
1958/// IS-ref support (a graph-level fact). Intermediate facts are expanded
1959/// transitively through the store.
1960fn compute_lineage(
1961    fact_hash: &[u8],
1962    tracker: &Arc<ProvenanceStore>,
1963    visited: &mut HashSet<Vec<u8>>,
1964) -> HashSet<Vec<u8>> {
1965    if !visited.insert(fact_hash.to_vec()) {
1966        return HashSet::new(); // Cycle guard.
1967    }
1968
1969    match tracker.lookup(fact_hash) {
1970        Some(entry) if !entry.support.is_empty() => {
1971            let mut bases = HashSet::new();
1972            for input in &entry.support {
1973                let child_bases = compute_lineage(&input.base_fact_id, tracker, visited);
1974                bases.extend(child_bases);
1975            }
1976            bases
1977        }
1978        _ => {
1979            // Base fact (no tracker entry or no inputs).
1980            let mut set = HashSet::new();
1981            set.insert(fact_hash.to_vec());
1982            set
1983        }
1984    }
1985}
1986
1987/// Determine which clause produced a given row by checking each clause's candidates.
1988///
1989/// Returns the index of the first clause whose candidates contain a matching row.
1990/// Falls back to 0 if no match is found.
1991fn find_clause_for_row(
1992    delta_batch: &RecordBatch,
1993    row_idx: usize,
1994    all_indices: &[usize],
1995    clause_candidates: &[Vec<RecordBatch>],
1996) -> usize {
1997    let target_key = extract_scalar_key(delta_batch, all_indices, row_idx);
1998    for (clause_idx, batches) in clause_candidates.iter().enumerate() {
1999        for batch in batches {
2000            if batch.num_columns() != all_indices.len() {
2001                continue;
2002            }
2003            for r in 0..batch.num_rows() {
2004                if extract_scalar_key(batch, all_indices, r) == target_key {
2005                    return clause_idx;
2006                }
2007            }
2008        }
2009    }
2010    0
2011}
2012
2013/// Convert a single row from a `RecordBatch` at `row_idx` into a `HashMap<String, Value>`.
2014fn batch_row_to_value_map(
2015    batch: &RecordBatch,
2016    row_idx: usize,
2017) -> std::collections::HashMap<String, Value> {
2018    use uni_store::storage::arrow_convert::arrow_to_value;
2019
2020    let schema = batch.schema();
2021    schema
2022        .fields()
2023        .iter()
2024        .enumerate()
2025        .map(|(col_idx, field)| {
2026            let col = batch.column(col_idx);
2027            let val = arrow_to_value(col.as_ref(), row_idx, None);
2028            (field.name().clone(), val)
2029        })
2030        .collect()
2031}
2032
2033/// Filter `batches` to exclude rows where `left_col` VID appears in `neg_facts[right_col]`.
2034///
2035/// Implements anti-join semantics for negated IS-refs (`n IS NOT rule`): keeps only
2036/// rows whose subject VID is NOT present in the negated rule's fully-converged facts.
2037pub fn apply_anti_join(
2038    batches: Vec<RecordBatch>,
2039    neg_facts: &[RecordBatch],
2040    left_col: &str,
2041    right_col: &str,
2042) -> datafusion::error::Result<Vec<RecordBatch>> {
2043    use arrow::compute::filter_record_batch;
2044    use arrow_array::{Array as _, BooleanArray, UInt64Array};
2045
2046    // Collect right-side VIDs from the negated rule's derived facts.
2047    let mut banned: std::collections::HashSet<u64> = std::collections::HashSet::new();
2048    for batch in neg_facts {
2049        let Ok(idx) = batch.schema().index_of(right_col) else {
2050            continue;
2051        };
2052        let arr = batch.column(idx);
2053        let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2054            continue;
2055        };
2056        for i in 0..vids.len() {
2057            if !vids.is_null(i) {
2058                banned.insert(vids.value(i));
2059            }
2060        }
2061    }
2062
2063    if banned.is_empty() {
2064        return Ok(batches);
2065    }
2066
2067    // Filter body batches: keep rows where left_col NOT IN banned.
2068    let mut result = Vec::new();
2069    for batch in batches {
2070        let Ok(idx) = batch.schema().index_of(left_col) else {
2071            result.push(batch);
2072            continue;
2073        };
2074        let arr = batch.column(idx);
2075        let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2076            result.push(batch);
2077            continue;
2078        };
2079        let keep: Vec<bool> = (0..vids.len())
2080            .map(|i| vids.is_null(i) || !banned.contains(&vids.value(i)))
2081            .collect();
2082        let keep_arr = BooleanArray::from(keep);
2083        let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
2084        if filtered.num_rows() > 0 {
2085            result.push(filtered);
2086        }
2087    }
2088    Ok(result)
2089}
2090
2091/// Probabilistic complement for negated IS-refs targeting PROB rules.
2092///
2093/// Instead of filtering out matching VIDs (anti-join), this adds a complement column
2094/// `__prob_complement_{rule_name}` with value `1 - p` for each matching VID, and `1.0`
2095/// for VIDs not present in the negated rule's facts.
2096///
2097/// This implements the probabilistic complement semantics: `IS NOT risk` on a PROB rule
2098/// yields the probability that the entity is NOT risky.
2099pub fn apply_prob_complement(
2100    batches: Vec<RecordBatch>,
2101    neg_facts: &[RecordBatch],
2102    left_col: &str,
2103    right_col: &str,
2104    prob_col: &str,
2105    complement_col_name: &str,
2106) -> datafusion::error::Result<Vec<RecordBatch>> {
2107    use arrow_array::{Array as _, Float64Array, UInt64Array};
2108
2109    // Build VID → probability lookup from negative facts
2110    let mut prob_map: std::collections::HashMap<u64, f64> = std::collections::HashMap::new();
2111    for batch in neg_facts {
2112        let Ok(vid_idx) = batch.schema().index_of(right_col) else {
2113            continue;
2114        };
2115        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
2116            continue;
2117        };
2118        let Some(vids) = batch.column(vid_idx).as_any().downcast_ref::<UInt64Array>() else {
2119            continue;
2120        };
2121        let prob_arr = batch.column(prob_idx);
2122        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
2123        for i in 0..vids.len() {
2124            if !vids.is_null(i) {
2125                let p = probs
2126                    .and_then(|arr| {
2127                        if arr.is_null(i) {
2128                            None
2129                        } else {
2130                            Some(arr.value(i))
2131                        }
2132                    })
2133                    .unwrap_or(0.0);
2134                // If multiple facts for same VID, use noisy-OR combination:
2135                // combined = 1 - (1 - existing) * (1 - new)
2136                prob_map
2137                    .entry(vids.value(i))
2138                    .and_modify(|existing| {
2139                        *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
2140                    })
2141                    .or_insert(p);
2142            }
2143        }
2144    }
2145
2146    // Add complement column to each batch
2147    let mut result = Vec::new();
2148    for batch in batches {
2149        let Ok(idx) = batch.schema().index_of(left_col) else {
2150            result.push(batch);
2151            continue;
2152        };
2153        let Some(vids) = batch.column(idx).as_any().downcast_ref::<UInt64Array>() else {
2154            result.push(batch);
2155            continue;
2156        };
2157
2158        // Compute complement values: 1 - p for matched VIDs, 1.0 for absent
2159        let complements: Vec<f64> = (0..vids.len())
2160            .map(|i| {
2161                if vids.is_null(i) {
2162                    1.0
2163                } else {
2164                    let p = prob_map.get(&vids.value(i)).copied().unwrap_or(0.0);
2165                    1.0 - p
2166                }
2167            })
2168            .collect();
2169
2170        let complement_arr = Float64Array::from(complements);
2171
2172        // Add the complement column to the batch
2173        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
2174        columns.push(std::sync::Arc::new(complement_arr));
2175
2176        let mut fields: Vec<std::sync::Arc<arrow_schema::Field>> =
2177            batch.schema().fields().iter().cloned().collect();
2178        fields.push(std::sync::Arc::new(arrow_schema::Field::new(
2179            complement_col_name,
2180            arrow_schema::DataType::Float64,
2181            true,
2182        )));
2183
2184        let new_schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
2185        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
2186        result.push(new_batch);
2187    }
2188    Ok(result)
2189}
2190
2191/// Probabilistic complement for composite (multi-column) join keys.
2192///
2193/// Builds a composite key from all `join_cols` right-side columns in
2194/// `neg_facts`, maps each composite key to a probability via noisy-OR
2195/// combination, then adds a single `complement_col_name` column with
2196/// `1 - p` for matched keys and `1.0` for absent keys.
2197pub fn apply_prob_complement_composite(
2198    batches: Vec<RecordBatch>,
2199    neg_facts: &[RecordBatch],
2200    join_cols: &[(String, String)],
2201    prob_col: &str,
2202    complement_col_name: &str,
2203) -> datafusion::error::Result<Vec<RecordBatch>> {
2204    use arrow_array::{Array as _, Float64Array, UInt64Array};
2205
2206    // Build composite-key → probability lookup from negative facts.
2207    let mut prob_map: HashMap<Vec<u64>, f64> = HashMap::new();
2208    for batch in neg_facts {
2209        let right_indices: Vec<usize> = join_cols
2210            .iter()
2211            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
2212            .collect();
2213        if right_indices.len() != join_cols.len() {
2214            continue;
2215        }
2216        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
2217            continue;
2218        };
2219        let prob_arr = batch.column(prob_idx);
2220        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
2221        for row in 0..batch.num_rows() {
2222            let mut key = Vec::with_capacity(right_indices.len());
2223            let mut valid = true;
2224            for &ci in &right_indices {
2225                let col = batch.column(ci);
2226                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
2227                    if vids.is_null(row) {
2228                        valid = false;
2229                        break;
2230                    }
2231                    key.push(vids.value(row));
2232                } else {
2233                    valid = false;
2234                    break;
2235                }
2236            }
2237            if !valid {
2238                continue;
2239            }
2240            let p = probs
2241                .and_then(|arr| {
2242                    if arr.is_null(row) {
2243                        None
2244                    } else {
2245                        Some(arr.value(row))
2246                    }
2247                })
2248                .unwrap_or(0.0);
2249            // Noisy-OR combination for duplicate composite keys.
2250            prob_map
2251                .entry(key)
2252                .and_modify(|existing| {
2253                    *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
2254                })
2255                .or_insert(p);
2256        }
2257    }
2258
2259    // Add complement column to each batch.
2260    let mut result = Vec::new();
2261    for batch in batches {
2262        let left_indices: Vec<usize> = join_cols
2263            .iter()
2264            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
2265            .collect();
2266        if left_indices.len() != join_cols.len() {
2267            result.push(batch);
2268            continue;
2269        }
2270        let all_u64 = left_indices.iter().all(|&ci| {
2271            batch
2272                .column(ci)
2273                .as_any()
2274                .downcast_ref::<UInt64Array>()
2275                .is_some()
2276        });
2277        if !all_u64 {
2278            result.push(batch);
2279            continue;
2280        }
2281
2282        let complements: Vec<f64> = (0..batch.num_rows())
2283            .map(|row| {
2284                let mut key = Vec::with_capacity(left_indices.len());
2285                for &ci in &left_indices {
2286                    let vids = batch
2287                        .column(ci)
2288                        .as_any()
2289                        .downcast_ref::<UInt64Array>()
2290                        .unwrap();
2291                    if vids.is_null(row) {
2292                        return 1.0;
2293                    }
2294                    key.push(vids.value(row));
2295                }
2296                let p = prob_map.get(&key).copied().unwrap_or(0.0);
2297                1.0 - p
2298            })
2299            .collect();
2300
2301        let complement_arr = Float64Array::from(complements);
2302        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
2303        columns.push(Arc::new(complement_arr));
2304
2305        let mut fields: Vec<Arc<arrow_schema::Field>> =
2306            batch.schema().fields().iter().cloned().collect();
2307        fields.push(Arc::new(arrow_schema::Field::new(
2308            complement_col_name,
2309            arrow_schema::DataType::Float64,
2310            true,
2311        )));
2312
2313        let new_schema = Arc::new(arrow_schema::Schema::new(fields));
2314        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
2315        result.push(new_batch);
2316    }
2317    Ok(result)
2318}
2319
2320/// Boolean anti-join for composite (multi-column) join keys.
2321///
2322/// Builds a `HashSet<Vec<u64>>` from `neg_facts` using all right-side
2323/// columns in `join_cols`, then filters `batches` to keep only rows
2324/// whose composite left-side key is NOT in the set.
2325pub fn apply_anti_join_composite(
2326    batches: Vec<RecordBatch>,
2327    neg_facts: &[RecordBatch],
2328    join_cols: &[(String, String)],
2329) -> datafusion::error::Result<Vec<RecordBatch>> {
2330    use arrow::compute::filter_record_batch;
2331    use arrow_array::{Array as _, BooleanArray, UInt64Array};
2332
2333    // Collect composite keys from the negated rule's derived facts.
2334    let mut banned: HashSet<Vec<u64>> = HashSet::new();
2335    for batch in neg_facts {
2336        let right_indices: Vec<usize> = join_cols
2337            .iter()
2338            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
2339            .collect();
2340        if right_indices.len() != join_cols.len() {
2341            continue;
2342        }
2343        for row in 0..batch.num_rows() {
2344            let mut key = Vec::with_capacity(right_indices.len());
2345            let mut valid = true;
2346            for &ci in &right_indices {
2347                let col = batch.column(ci);
2348                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
2349                    if vids.is_null(row) {
2350                        valid = false;
2351                        break;
2352                    }
2353                    key.push(vids.value(row));
2354                } else {
2355                    valid = false;
2356                    break;
2357                }
2358            }
2359            if valid {
2360                banned.insert(key);
2361            }
2362        }
2363    }
2364
2365    if banned.is_empty() {
2366        return Ok(batches);
2367    }
2368
2369    // Filter body batches: keep rows where composite left key NOT IN banned.
2370    let mut result = Vec::new();
2371    for batch in batches {
2372        let left_indices: Vec<usize> = join_cols
2373            .iter()
2374            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
2375            .collect();
2376        if left_indices.len() != join_cols.len() {
2377            result.push(batch);
2378            continue;
2379        }
2380        let all_u64 = left_indices.iter().all(|&ci| {
2381            batch
2382                .column(ci)
2383                .as_any()
2384                .downcast_ref::<UInt64Array>()
2385                .is_some()
2386        });
2387        if !all_u64 {
2388            result.push(batch);
2389            continue;
2390        }
2391
2392        let keep: Vec<bool> = (0..batch.num_rows())
2393            .map(|row| {
2394                let mut key = Vec::with_capacity(left_indices.len());
2395                for &ci in &left_indices {
2396                    let vids = batch
2397                        .column(ci)
2398                        .as_any()
2399                        .downcast_ref::<UInt64Array>()
2400                        .unwrap();
2401                    if vids.is_null(row) {
2402                        return true; // null keys are never banned
2403                    }
2404                    key.push(vids.value(row));
2405                }
2406                !banned.contains(&key)
2407            })
2408            .collect();
2409        let keep_arr = BooleanArray::from(keep);
2410        let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
2411        if filtered.num_rows() > 0 {
2412            result.push(filtered);
2413        }
2414    }
2415    Ok(result)
2416}
2417
2418/// Multiply `__prob_complement_*` columns into the rule's PROB column and clean up.
2419///
2420/// After IS NOT probabilistic complement semantics have added `__prob_complement_*`
2421/// columns to clause results, this function:
2422/// 1. Computes the product of all complement factor columns
2423/// 2. Multiplies the product into the existing PROB column (if any)
2424/// 3. Removes the internal `__prob_complement_*` columns from the output
2425///
2426/// If the rule has no PROB column, complement columns are simply removed
2427/// (the complement information is discarded and IS NOT acts as a keep-all).
2428pub fn multiply_prob_factors(
2429    batches: Vec<RecordBatch>,
2430    prob_col: Option<&str>,
2431    complement_cols: &[String],
2432) -> datafusion::error::Result<Vec<RecordBatch>> {
2433    use arrow_array::{Array as _, Float64Array};
2434
2435    let mut result = Vec::with_capacity(batches.len());
2436
2437    for batch in batches {
2438        if batch.num_rows() == 0 {
2439            // Remove complement columns from empty batches
2440            let keep: Vec<usize> = batch
2441                .schema()
2442                .fields()
2443                .iter()
2444                .enumerate()
2445                .filter(|(_, f)| !complement_cols.contains(f.name()))
2446                .map(|(i, _)| i)
2447                .collect();
2448            let fields: Vec<_> = keep
2449                .iter()
2450                .map(|&i| batch.schema().field(i).clone())
2451                .collect();
2452            let cols: Vec<_> = keep.iter().map(|&i| batch.column(i).clone()).collect();
2453            let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
2454            result.push(
2455                RecordBatch::try_new(schema, cols).map_err(|e| {
2456                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
2457                })?,
2458            );
2459            continue;
2460        }
2461
2462        let num_rows = batch.num_rows();
2463
2464        // 1. Compute product of all complement factors
2465        let mut combined = vec![1.0f64; num_rows];
2466        for col_name in complement_cols {
2467            if let Ok(idx) = batch.schema().index_of(col_name) {
2468                let arr = batch
2469                    .column(idx)
2470                    .as_any()
2471                    .downcast_ref::<Float64Array>()
2472                    .ok_or_else(|| {
2473                        datafusion::error::DataFusionError::Internal(format!(
2474                            "Expected Float64 for complement column {col_name}"
2475                        ))
2476                    })?;
2477                for (i, val) in combined.iter_mut().enumerate().take(num_rows) {
2478                    if !arr.is_null(i) {
2479                        *val *= arr.value(i);
2480                    }
2481                }
2482            }
2483        }
2484
2485        // 2. If there's a PROB column, multiply combined into it
2486        let final_prob: Vec<f64> = if let Some(prob_name) = prob_col {
2487            if let Ok(idx) = batch.schema().index_of(prob_name) {
2488                let arr = batch
2489                    .column(idx)
2490                    .as_any()
2491                    .downcast_ref::<Float64Array>()
2492                    .ok_or_else(|| {
2493                        datafusion::error::DataFusionError::Internal(format!(
2494                            "Expected Float64 for PROB column {prob_name}"
2495                        ))
2496                    })?;
2497                (0..num_rows)
2498                    .map(|i| {
2499                        if arr.is_null(i) {
2500                            combined[i]
2501                        } else {
2502                            arr.value(i) * combined[i]
2503                        }
2504                    })
2505                    .collect()
2506            } else {
2507                combined
2508            }
2509        } else {
2510            combined
2511        };
2512
2513        let new_prob_array: arrow_array::ArrayRef =
2514            std::sync::Arc::new(Float64Array::from(final_prob));
2515
2516        // 3. Build output: replace PROB column, remove complement columns
2517        let mut fields = Vec::new();
2518        let mut columns = Vec::new();
2519
2520        for (idx, field) in batch.schema().fields().iter().enumerate() {
2521            if complement_cols.contains(field.name()) {
2522                continue;
2523            }
2524            if prob_col.is_some_and(|p| field.name() == p) {
2525                fields.push(field.clone());
2526                columns.push(new_prob_array.clone());
2527            } else {
2528                fields.push(field.clone());
2529                columns.push(batch.column(idx).clone());
2530            }
2531        }
2532
2533        let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
2534        result.push(RecordBatch::try_new(schema, columns).map_err(arrow_err)?);
2535    }
2536
2537    Ok(result)
2538}
2539
2540/// Update derived scan handles before evaluating a rule's clause bodies.
2541///
2542/// For self-references: inject delta (semi-naive optimization).
2543/// For cross-references: inject full facts.
2544fn update_derived_scan_handles(
2545    registry: &DerivedScanRegistry,
2546    states: &[FixpointState],
2547    current_rule_idx: usize,
2548    rules: &[FixpointRulePlan],
2549) {
2550    let current_rule_name = &rules[current_rule_idx].name;
2551
2552    for entry in &registry.entries {
2553        // Find the state for this entry's rule
2554        let source_state_idx = rules.iter().position(|r| r.name == entry.rule_name);
2555        let Some(source_idx) = source_state_idx else {
2556            continue;
2557        };
2558
2559        let is_self = entry.rule_name == *current_rule_name;
2560        let data = if is_self {
2561            // Self-ref: inject delta for semi-naive
2562            states[source_idx].all_delta().to_vec()
2563        } else {
2564            // Cross-ref: inject full facts
2565            states[source_idx].all_facts().to_vec()
2566        };
2567
2568        // If empty, write an empty batch so the scan returns zero rows
2569        let data = if data.is_empty() || data.iter().all(|b| b.num_rows() == 0) {
2570            vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
2571        } else {
2572            data
2573        };
2574
2575        let mut guard = entry.data.write();
2576        *guard = data;
2577    }
2578}
2579
2580// ---------------------------------------------------------------------------
2581// DerivedScanExec — physical plan that reads from shared data at execution time
2582// ---------------------------------------------------------------------------
2583
/// Physical plan for `LocyDerivedScan` that reads from a shared `Arc<RwLock>` at
/// execution time (not at plan creation time).
///
/// This is critical for fixpoint iteration: the data handle is updated between
/// iterations, and each re-execution of the subplan must read the latest data.
pub struct DerivedScanExec {
    /// Shared handle rewritten by the fixpoint driver between iterations.
    data: Arc<RwLock<Vec<RecordBatch>>>,
    /// Schema of the derived relation; also used to emit an empty placeholder batch.
    schema: SchemaRef,
    /// Cached plan properties computed from `schema` at construction.
    properties: PlanProperties,
}
2594
2595impl DerivedScanExec {
2596    pub fn new(data: Arc<RwLock<Vec<RecordBatch>>>, schema: SchemaRef) -> Self {
2597        let properties = compute_plan_properties(Arc::clone(&schema));
2598        Self {
2599            data,
2600            schema,
2601            properties,
2602        }
2603    }
2604}
2605
impl fmt::Debug for DerivedScanExec {
    // Shows only the schema; the (potentially large) shared data handle and
    // the derived plan properties are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("DerivedScanExec")
            .field("schema", &self.schema)
            .finish()
    }
}
2613
impl DisplayAs for DerivedScanExec {
    // EXPLAIN rendering: a fixed label with no per-instance detail.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "DerivedScanExec")
    }
}
2619
2620impl ExecutionPlan for DerivedScanExec {
2621    fn name(&self) -> &str {
2622        "DerivedScanExec"
2623    }
2624    fn as_any(&self) -> &dyn Any {
2625        self
2626    }
2627    fn schema(&self) -> SchemaRef {
2628        Arc::clone(&self.schema)
2629    }
2630    fn properties(&self) -> &PlanProperties {
2631        &self.properties
2632    }
2633    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
2634        vec![]
2635    }
2636    fn with_new_children(
2637        self: Arc<Self>,
2638        _children: Vec<Arc<dyn ExecutionPlan>>,
2639    ) -> DFResult<Arc<dyn ExecutionPlan>> {
2640        Ok(self)
2641    }
2642    fn execute(
2643        &self,
2644        _partition: usize,
2645        _context: Arc<TaskContext>,
2646    ) -> DFResult<SendableRecordBatchStream> {
2647        let batches = {
2648            let guard = self.data.read();
2649            if guard.is_empty() {
2650                vec![RecordBatch::new_empty(Arc::clone(&self.schema))]
2651            } else {
2652                guard.clone()
2653            }
2654        };
2655        Ok(Box::pin(MemoryStream::try_new(
2656            batches,
2657            Arc::clone(&self.schema),
2658            None,
2659        )?))
2660    }
2661}
2662
2663// ---------------------------------------------------------------------------
2664// InMemoryExec — wrapper to feed Vec<RecordBatch> into operator chains
2665// ---------------------------------------------------------------------------
2666
/// Simple in-memory execution plan that serves pre-computed batches.
///
/// Used internally to feed fixpoint results into post-fixpoint operator chains
/// (FOLD, BEST BY). Not exported — only used within this module.
struct InMemoryExec {
    /// Pre-computed batches served verbatim on every execution.
    batches: Vec<RecordBatch>,
    /// Schema of the served batches.
    schema: SchemaRef,
    /// Cached plan properties computed from `schema` at construction.
    properties: PlanProperties,
}
2676
2677impl InMemoryExec {
2678    fn new(batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
2679        let properties = compute_plan_properties(Arc::clone(&schema));
2680        Self {
2681            batches,
2682            schema,
2683            properties,
2684        }
2685    }
2686}
2687
impl fmt::Debug for InMemoryExec {
    // Shows batch count and schema; batch contents are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("InMemoryExec")
            .field("num_batches", &self.batches.len())
            .field("schema", &self.schema)
            .finish()
    }
}
2696
impl DisplayAs for InMemoryExec {
    // EXPLAIN rendering: label plus the number of served batches.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "InMemoryExec: batches={}", self.batches.len())
    }
}
2702
2703impl ExecutionPlan for InMemoryExec {
2704    fn name(&self) -> &str {
2705        "InMemoryExec"
2706    }
2707    fn as_any(&self) -> &dyn Any {
2708        self
2709    }
2710    fn schema(&self) -> SchemaRef {
2711        Arc::clone(&self.schema)
2712    }
2713    fn properties(&self) -> &PlanProperties {
2714        &self.properties
2715    }
2716    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
2717        vec![]
2718    }
2719    fn with_new_children(
2720        self: Arc<Self>,
2721        _children: Vec<Arc<dyn ExecutionPlan>>,
2722    ) -> DFResult<Arc<dyn ExecutionPlan>> {
2723        Ok(self)
2724    }
2725    fn execute(
2726        &self,
2727        _partition: usize,
2728        _context: Arc<TaskContext>,
2729    ) -> DFResult<SendableRecordBatchStream> {
2730        Ok(Box::pin(MemoryStream::try_new(
2731            self.batches.clone(),
2732            Arc::clone(&self.schema),
2733            None,
2734        )?))
2735    }
2736}
2737
2738// ---------------------------------------------------------------------------
2739// Post-fixpoint chain — FOLD and BEST BY on converged facts
2740// ---------------------------------------------------------------------------
2741
/// Apply post-FOLD WHERE (HAVING) filter to aggregated batches.
///
/// Converts each Cypher HAVING expression to a DataFusion physical expression
/// via `cypher_expr_to_df` → type coercion → `create_physical_expr`, evaluates
/// against the FOLD output, and keeps only rows where all conditions hold.
/// Batches whose rows are all filtered out are dropped from the result.
fn apply_having_filter(
    batches: Vec<RecordBatch>,
    having_exprs: &[Expr],
    schema: &SchemaRef,
) -> DFResult<Vec<RecordBatch>> {
    use arrow::compute::{and, filter_record_batch};
    use arrow_array::BooleanArray;
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::LogicalPlanBuilder;
    use datafusion::optimizer::AnalyzerRule;
    use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
    use datafusion::physical_expr::create_physical_expr;
    use datafusion::prelude::SessionContext;

    // No input — nothing to filter.
    if batches.is_empty() {
        return Ok(batches);
    }

    // Build DFSchema from the FOLD output Arrow schema.
    let df_schema = DFSchema::try_from(schema.as_ref().clone()).map_err(|e| {
        datafusion::common::DataFusionError::Internal(format!("HAVING schema conversion: {e}"))
    })?;

    // A throwaway session supplies the config options and execution props
    // required by the analyzer rule and physical-expression planner below.
    let ctx = SessionContext::new();
    let state = ctx.state();
    let config = state.config_options().clone();
    let props = state.execution_props();

    // Cypher Expr → DataFusion DfExpr → type-coerced DfExpr → PhysicalExpr.
    //
    // Type coercion is needed because FOLD aggregates produce Float64 (SUM,
    // AVG) or Int64 (COUNT), and literal comparisons like `total >= 100`
    // may mix Float64 columns with Int64 literals.
    let physical_exprs: Vec<Arc<dyn datafusion::physical_expr::PhysicalExpr>> = having_exprs
        .iter()
        .map(|expr| {
            let df_expr = crate::query::df_expr::cypher_expr_to_df(expr, None).map_err(|e| {
                datafusion::common::DataFusionError::Internal(format!(
                    "HAVING expression conversion: {e}"
                ))
            })?;

            // Run DataFusion's type coercion by wrapping in a Filter plan,
            // applying the TypeCoercion analyzer rule, then extracting the
            // coerced predicate.
            let empty = datafusion::logical_expr::LogicalPlan::EmptyRelation(
                datafusion::logical_expr::EmptyRelation {
                    produce_one_row: false,
                    schema: Arc::new(df_schema.clone()),
                },
            );
            let filter_plan = LogicalPlanBuilder::from(empty)
                .filter(df_expr.clone())?
                .build()?;
            // If the analyzer fails or returns anything other than a Filter,
            // fall back to the uncoerced expression instead of erroring out.
            let coerced_expr = match TypeCoercion::new().analyze(filter_plan, &config) {
                Ok(datafusion::logical_expr::LogicalPlan::Filter(f)) => f.predicate,
                _ => df_expr,
            };

            create_physical_expr(&coerced_expr, &df_schema, props)
        })
        .collect::<DFResult<Vec<_>>>()?;

    let mut result = Vec::new();
    for batch in batches {
        // Evaluate each condition and AND the boolean masks.
        let mut mask: Option<BooleanArray> = None;
        for phys_expr in &physical_exprs {
            let value = phys_expr.evaluate(&batch)?;
            let arr = value.into_array(batch.num_rows())?;
            let bool_arr = arr.as_any().downcast_ref::<BooleanArray>().ok_or_else(|| {
                datafusion::common::DataFusionError::Internal(
                    "HAVING condition must evaluate to boolean".into(),
                )
            })?;
            mask = Some(match mask {
                None => bool_arr.clone(),
                Some(prev) => and(&prev, bool_arr).map_err(arrow_err)?,
            });
        }
        if let Some(ref m) = mask {
            // Keep only batches with surviving rows.
            let filtered = filter_record_batch(&batch, m).map_err(arrow_err)?;
            if filtered.num_rows() > 0 {
                result.push(filtered);
            }
        } else {
            // No HAVING conditions were given: pass the batch through.
            result.push(batch);
        }
    }
    Ok(result)
}
2838
2839/// Apply post-fixpoint operators (FOLD, HAVING, BEST BY, PRIORITY) to converged facts.
2840pub(crate) async fn apply_post_fixpoint_chain(
2841    facts: Vec<RecordBatch>,
2842    rule: &FixpointRulePlan,
2843    task_ctx: &Arc<TaskContext>,
2844    strict_probability_domain: bool,
2845    probability_epsilon: f64,
2846) -> DFResult<Vec<RecordBatch>> {
2847    if !rule.has_fold && !rule.has_best_by && !rule.has_priority && rule.having.is_empty() {
2848        return Ok(facts);
2849    }
2850
2851    // Wrap facts in InMemoryExec.
2852    // Prefer the actual batch schema (from physical execution) over the
2853    // pre-computed yield_schema, which may have wrong inferred types
2854    // (e.g. Float64 for a string property).
2855    let schema = facts
2856        .iter()
2857        .find(|b| b.num_rows() > 0)
2858        .map(|b| b.schema())
2859        .unwrap_or_else(|| Arc::clone(&rule.yield_schema));
2860    let input: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(facts, schema));
2861
2862    // Apply PRIORITY first — keeps only rows with max __priority per KEY group,
2863    // then strips the __priority column from output.
2864    // Must run before FOLD so that the __priority column is still present.
2865    let current: Arc<dyn ExecutionPlan> = if rule.has_priority {
2866        let priority_schema = input.schema();
2867        let priority_idx = priority_schema.index_of("__priority").map_err(|_| {
2868            datafusion::common::DataFusionError::Internal(
2869                "PRIORITY rule missing __priority column".to_string(),
2870            )
2871        })?;
2872        Arc::new(PriorityExec::new(
2873            input,
2874            rule.key_column_indices.clone(),
2875            priority_idx,
2876        ))
2877    } else {
2878        input
2879    };
2880
2881    // Apply FOLD
2882    let current: Arc<dyn ExecutionPlan> = if rule.has_fold && !rule.fold_bindings.is_empty() {
2883        Arc::new(FoldExec::new(
2884            current,
2885            rule.key_column_indices.clone(),
2886            rule.fold_bindings.clone(),
2887            strict_probability_domain,
2888            probability_epsilon,
2889        ))
2890    } else {
2891        current
2892    };
2893
2894    // Apply HAVING (post-FOLD WHERE filter)
2895    let current: Arc<dyn ExecutionPlan> = if !rule.having.is_empty() {
2896        let batches = collect_all_partitions(&current, Arc::clone(task_ctx)).await?;
2897        let filtered = apply_having_filter(batches, &rule.having, &current.schema())?;
2898        if filtered.is_empty() {
2899            return Ok(filtered);
2900        }
2901        Arc::new(InMemoryExec::new(filtered, Arc::clone(&current.schema())))
2902    } else {
2903        current
2904    };
2905
2906    // Apply BEST BY
2907    let current: Arc<dyn ExecutionPlan> = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
2908        Arc::new(BestByExec::new(
2909            current,
2910            rule.key_column_indices.clone(),
2911            rule.best_by_criteria.clone(),
2912            rule.deterministic,
2913        ))
2914    } else {
2915        current
2916    };
2917
2918    collect_all_partitions(&current, Arc::clone(task_ctx)).await
2919}
2920
2921// ---------------------------------------------------------------------------
2922// FixpointExec — DataFusion ExecutionPlan
2923// ---------------------------------------------------------------------------
2924
/// DataFusion `ExecutionPlan` that drives semi-naive fixpoint iteration.
///
/// Has no physical children: clause bodies are re-planned from logical plans
/// on each iteration (same pattern as `RecursiveCTEExec` and `GraphApplyExec`).
pub struct FixpointExec {
    /// Rules of the recursive stratum evaluated each iteration.
    rules: Vec<FixpointRulePlan>,
    /// Upper bound on the number of fixpoint iterations.
    max_iterations: usize,
    /// Wall-clock budget for the fixpoint loop.
    timeout: Duration,
    /// Graph execution context shared with clause-body evaluation.
    graph_ctx: Arc<GraphExecutionContext>,
    /// DataFusion session used to re-plan clause bodies each iteration.
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    /// Storage backend handle.
    storage: Arc<StorageManager>,
    /// Logical schema metadata for the graph.
    schema_info: Arc<UniSchema>,
    /// Query parameters forwarded to clause-body evaluation.
    params: HashMap<String, Value>,
    /// Registry of derived-scan handles rewritten between iterations.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    /// Schema of this operator's output batches.
    output_schema: SchemaRef,
    /// Cached plan properties computed from `output_schema`.
    properties: PlanProperties,
    /// Standard per-partition execution metrics.
    metrics: ExecutionPlanMetricsSet,
    /// Cap (in bytes) on accumulated derived facts.
    max_derived_bytes: usize,
    /// Optional provenance tracker populated during fixpoint iteration.
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    /// Shared slot written with per-rule iteration counts after convergence.
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    /// Forwarded to probability-aware operators (e.g. FoldExec).
    strict_probability_domain: bool,
    /// Forwarded to probability-aware operators (e.g. FoldExec).
    probability_epsilon: f64,
    // NOTE(review): presumably selects exact BDD-based probability computation
    // (see `approximate_slot`) — confirm against the iteration driver.
    exact_probability: bool,
    // NOTE(review): presumably the BDD variable budget before falling back to
    // independence mode — confirm against the iteration driver.
    max_bdd_variables: usize,
    /// Shared slot for runtime warnings collected during fixpoint iteration.
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    /// Shared slot for groups where BDD fell back to independence mode.
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    /// When > 0, retain at most this many proofs per fact (top-k provenance).
    top_k_proofs: usize,
}
2958
impl fmt::Debug for FixpointExec {
    // Summarizes counts and limits only; contexts, plans, and shared slots
    // are elided via `finish_non_exhaustive`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("FixpointExec")
            .field("rules_count", &self.rules.len())
            .field("max_iterations", &self.max_iterations)
            .field("timeout", &self.timeout)
            .field("output_schema", &self.output_schema)
            .field("max_derived_bytes", &self.max_derived_bytes)
            .finish_non_exhaustive()
    }
}
2970
2971impl FixpointExec {
2972    /// Create a new `FixpointExec`.
2973    #[expect(
2974        clippy::too_many_arguments,
2975        reason = "FixpointExec configuration needs all context"
2976    )]
2977    pub fn new(
2978        rules: Vec<FixpointRulePlan>,
2979        max_iterations: usize,
2980        timeout: Duration,
2981        graph_ctx: Arc<GraphExecutionContext>,
2982        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
2983        storage: Arc<StorageManager>,
2984        schema_info: Arc<UniSchema>,
2985        params: HashMap<String, Value>,
2986        derived_scan_registry: Arc<DerivedScanRegistry>,
2987        output_schema: SchemaRef,
2988        max_derived_bytes: usize,
2989        derivation_tracker: Option<Arc<ProvenanceStore>>,
2990        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
2991        strict_probability_domain: bool,
2992        probability_epsilon: f64,
2993        exact_probability: bool,
2994        max_bdd_variables: usize,
2995        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
2996        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
2997        top_k_proofs: usize,
2998    ) -> Self {
2999        let properties = compute_plan_properties(Arc::clone(&output_schema));
3000        Self {
3001            rules,
3002            max_iterations,
3003            timeout,
3004            graph_ctx,
3005            session_ctx,
3006            storage,
3007            schema_info,
3008            params,
3009            derived_scan_registry,
3010            output_schema,
3011            properties,
3012            metrics: ExecutionPlanMetricsSet::new(),
3013            max_derived_bytes,
3014            derivation_tracker,
3015            iteration_counts,
3016            strict_probability_domain,
3017            probability_epsilon,
3018            exact_probability,
3019            max_bdd_variables,
3020            warnings_slot,
3021            approximate_slot,
3022            top_k_proofs,
3023        }
3024    }
3025
3026    /// Returns the shared iteration counts slot for post-execution inspection.
3027    pub fn iteration_counts(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
3028        Arc::clone(&self.iteration_counts)
3029    }
3030}
3031
3032impl DisplayAs for FixpointExec {
3033    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3034        write!(
3035            f,
3036            "FixpointExec: rules=[{}], max_iter={}, timeout={:?}",
3037            self.rules
3038                .iter()
3039                .map(|r| r.name.as_str())
3040                .collect::<Vec<_>>()
3041                .join(", "),
3042            self.max_iterations,
3043            self.timeout,
3044        )
3045    }
3046}
3047
impl ExecutionPlan for FixpointExec {
    fn name(&self) -> &str {
        "FixpointExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.output_schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        // No physical children — clause bodies are re-planned each iteration
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        // Leaf operator: only an empty child list is valid, and the node is
        // returned unchanged.
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Plan(
                "FixpointExec has no children".to_string(),
            ));
        }
        Ok(self)
    }

    /// Launches the fixpoint loop for one partition.
    ///
    /// Every configuration field is cloned into owned values first so the
    /// `async move` closure is self-contained (`'static`); the loop itself
    /// runs inside the boxed future and the returned [`FixpointStream`]
    /// yields the accumulated result batches one at a time.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        // `partition` is only used to attribute baseline metrics to this
        // partition's stream.
        let metrics = BaselineMetrics::new(&self.metrics, partition);

        // Clone all fields for the async closure
        let rules = self
            .rules
            .iter()
            .map(|r| {
                // We need to clone the FixpointRulePlan, but it contains LogicalPlan
                // which doesn't implement Clone traditionally. However, our LogicalPlan
                // does implement Clone since it's an enum.
                // (Field-by-field reconstruction below mirrors the struct exactly.)
                FixpointRulePlan {
                    name: r.name.clone(),
                    clauses: r
                        .clauses
                        .iter()
                        .map(|c| FixpointClausePlan {
                            body_logical: c.body_logical.clone(),
                            is_ref_bindings: c.is_ref_bindings.clone(),
                            priority: c.priority,
                            along_bindings: c.along_bindings.clone(),
                        })
                        .collect(),
                    yield_schema: Arc::clone(&r.yield_schema),
                    key_column_indices: r.key_column_indices.clone(),
                    priority: r.priority,
                    has_fold: r.has_fold,
                    fold_bindings: r.fold_bindings.clone(),
                    having: r.having.clone(),
                    has_best_by: r.has_best_by,
                    best_by_criteria: r.best_by_criteria.clone(),
                    has_priority: r.has_priority,
                    deterministic: r.deterministic,
                    prob_column_name: r.prob_column_name.clone(),
                }
            })
            .collect();

        // Owned copies of loop limits, shared context handles, and
        // probability-semantics knobs, all moved into the future below.
        let max_iterations = self.max_iterations;
        let timeout = self.timeout;
        let graph_ctx = Arc::clone(&self.graph_ctx);
        let session_ctx = Arc::clone(&self.session_ctx);
        let storage = Arc::clone(&self.storage);
        let schema_info = Arc::clone(&self.schema_info);
        let params = self.params.clone();
        let registry = Arc::clone(&self.derived_scan_registry);
        let output_schema = Arc::clone(&self.output_schema);
        let max_derived_bytes = self.max_derived_bytes;
        let derivation_tracker = self.derivation_tracker.clone();
        let iteration_counts = Arc::clone(&self.iteration_counts);
        let strict_probability_domain = self.strict_probability_domain;
        let probability_epsilon = self.probability_epsilon;
        let exact_probability = self.exact_probability;
        let max_bdd_variables = self.max_bdd_variables;
        let warnings_slot = Arc::clone(&self.warnings_slot);
        let approximate_slot = Arc::clone(&self.approximate_slot);
        let top_k_proofs = self.top_k_proofs;

        // NOTE: arguments are passed positionally — keep this list in sync
        // with `run_fixpoint_loop`'s parameter order.
        let fut = async move {
            run_fixpoint_loop(
                rules,
                max_iterations,
                timeout,
                graph_ctx,
                session_ctx,
                storage,
                schema_info,
                params,
                registry,
                output_schema,
                max_derived_bytes,
                derivation_tracker,
                iteration_counts,
                strict_probability_domain,
                probability_epsilon,
                exact_probability,
                max_bdd_variables,
                warnings_slot,
                approximate_slot,
                top_k_proofs,
            )
            .await
        };

        // The stream starts in `Running` and switches to `Emitting` once the
        // future resolves with the final batch set.
        Ok(Box::pin(FixpointStream {
            state: FixpointStreamState::Running(Box::pin(fut)),
            schema: Arc::clone(&self.output_schema),
            metrics,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
3181
3182// ---------------------------------------------------------------------------
3183// FixpointStream — async state machine for streaming results
3184// ---------------------------------------------------------------------------
3185
/// Phases of [`FixpointStream`]: run the fixpoint future to completion, then
/// drain the accumulated batches.
enum FixpointStreamState {
    /// Fixpoint loop is running; holds the boxed future that yields the full
    /// result set at once.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// Emitting accumulated result batches one at a time; the `usize` is the
    /// index of the next batch to emit.
    Emitting(Vec<RecordBatch>, usize),
    /// All batches emitted.
    Done,
}
3194
/// Stream adapter over the fixpoint computation: drives the inner future,
/// then yields its result batches one by one.
struct FixpointStream {
    // Current phase of the run-then-emit state machine.
    state: FixpointStreamState,
    // Schema reported via `RecordBatchStream::schema`.
    schema: SchemaRef,
    // Baseline metrics; output rows are recorded as each batch is emitted.
    metrics: BaselineMetrics,
}
3200
3201impl Stream for FixpointStream {
3202    type Item = DFResult<RecordBatch>;
3203
3204    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
3205        let this = self.get_mut();
3206        loop {
3207            match &mut this.state {
3208                FixpointStreamState::Running(fut) => match fut.as_mut().poll(cx) {
3209                    Poll::Ready(Ok(batches)) => {
3210                        if batches.is_empty() {
3211                            this.state = FixpointStreamState::Done;
3212                            return Poll::Ready(None);
3213                        }
3214                        this.state = FixpointStreamState::Emitting(batches, 0);
3215                        // Loop to emit first batch
3216                    }
3217                    Poll::Ready(Err(e)) => {
3218                        this.state = FixpointStreamState::Done;
3219                        return Poll::Ready(Some(Err(e)));
3220                    }
3221                    Poll::Pending => return Poll::Pending,
3222                },
3223                FixpointStreamState::Emitting(batches, idx) => {
3224                    if *idx >= batches.len() {
3225                        this.state = FixpointStreamState::Done;
3226                        return Poll::Ready(None);
3227                    }
3228                    let batch = batches[*idx].clone();
3229                    *idx += 1;
3230                    this.metrics.record_output(batch.num_rows());
3231                    return Poll::Ready(Some(Ok(batch)));
3232                }
3233                FixpointStreamState::Done => return Poll::Ready(None),
3234            }
3235        }
3236    }
3237}
3238
impl RecordBatchStream for FixpointStream {
    /// Returns the stream's output schema (shared handle, not a copy).
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
3244
3245// ---------------------------------------------------------------------------
3246// Unit tests
3247// ---------------------------------------------------------------------------
3248
3249#[cfg(test)]
3250mod tests {
3251    use super::*;
3252    use arrow_array::{Float64Array, Int64Array, StringArray};
3253    use arrow_schema::{DataType, Field, Schema};
3254
3255    fn test_schema() -> SchemaRef {
3256        Arc::new(Schema::new(vec![
3257            Field::new("name", DataType::Utf8, true),
3258            Field::new("value", DataType::Int64, true),
3259        ]))
3260    }
3261
3262    fn make_batch(names: &[&str], values: &[i64]) -> RecordBatch {
3263        RecordBatch::try_new(
3264            test_schema(),
3265            vec![
3266                Arc::new(StringArray::from(
3267                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
3268                )),
3269                Arc::new(Int64Array::from(values.to_vec())),
3270            ],
3271        )
3272        .unwrap()
3273    }
3274
3275    // --- FixpointState dedup tests ---
3276
    #[tokio::test]
    async fn test_fixpoint_state_empty_facts_adds_all() {
        // With no prior facts, every candidate row is new: both the fact set
        // and the delta must contain all three rows after the first merge.
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
        let changed = state.merge_delta(vec![batch], None).await.unwrap();

        assert!(changed);
        assert_eq!(state.all_facts().len(), 1);
        assert_eq!(state.all_facts()[0].num_rows(), 3);
        assert_eq!(state.all_delta().len(), 1);
        assert_eq!(state.all_delta()[0].num_rows(), 3);
    }
3291
    #[tokio::test]
    async fn test_fixpoint_state_exact_duplicates_excluded() {
        // Re-submitting rows that are already facts must report "no change"
        // and leave the delta empty (or all-empty batches).
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        state.merge_delta(vec![batch1], None).await.unwrap();

        // Same rows again
        let batch2 = make_batch(&["a", "b"], &[1, 2]);
        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
        assert!(!changed);
        assert!(
            state.all_delta().is_empty() || state.all_delta().iter().all(|b| b.num_rows() == 0)
        );
    }
3308
    #[tokio::test]
    async fn test_fixpoint_state_partial_overlap() {
        // A candidate batch mixing one known row and one new row must keep
        // only the new row in the delta while growing the total fact count.
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        state.merge_delta(vec![batch1], None).await.unwrap();

        // "a":1 is duplicate, "c":3 is new
        let batch2 = make_batch(&["a", "c"], &[1, 3]);
        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
        assert!(changed);

        // Delta should have only "c":3
        let delta_rows: usize = state.all_delta().iter().map(|b| b.num_rows()).sum();
        assert_eq!(delta_rows, 1);

        // Total facts: a:1, b:2, c:3
        let total_rows: usize = state.all_facts().iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 3);
    }
3330
    #[tokio::test]
    async fn test_fixpoint_state_convergence() {
        // Merging an empty candidate set must report "no change" and flip the
        // state into the converged condition.
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch = make_batch(&["a"], &[1]);
        state.merge_delta(vec![batch], None).await.unwrap();

        // Empty candidates → converged
        let changed = state.merge_delta(vec![], None).await.unwrap();
        assert!(!changed);
        assert!(state.is_converged());
    }
3344
3345    // --- RowDedupState tests ---
3346
    #[test]
    fn test_row_dedup_persistent_across_calls() {
        // RowDedupState should remember rows from the first call so the second
        // call does not re-accept them (O(M) per iteration, no facts re-scan).
        let schema = test_schema();
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        let delta1 = rd.compute_delta(&[batch1], &schema).unwrap();
        // First call: both rows are new.
        let rows1: usize = delta1.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows1, 2);

        // Second call with same rows: seen set already has them → empty delta.
        let batch2 = make_batch(&["a", "b"], &[1, 2]);
        let delta2 = rd.compute_delta(&[batch2], &schema).unwrap();
        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows2, 0);

        // Third call with one old + one new: only the new row is returned.
        let batch3 = make_batch(&["a", "c"], &[1, 3]);
        let delta3 = rd.compute_delta(&[batch3], &schema).unwrap();
        let rows3: usize = delta3.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows3, 1);
    }
3372
    #[test]
    fn test_row_dedup_null_handling() {
        // NULLs must participate in row identity: identical all-NULL-column
        // rows are duplicates, but a differing non-null column keeps rows
        // distinct.
        use arrow_array::StringArray;
        use arrow_schema::{DataType, Field, Schema};

        let schema: SchemaRef = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, true),
            Field::new("b", DataType::Int64, true),
        ]));
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");

        // Two rows: (NULL, 1) and (NULL, 1) — same NULLs → duplicate.
        let batch_nulls = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![None::<&str>, None::<&str>])),
                Arc::new(arrow_array::Int64Array::from(vec![1i64, 1i64])),
            ],
        )
        .unwrap();
        let delta = rd.compute_delta(&[batch_nulls], &schema).unwrap();
        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows, 1, "two identical NULL rows should be deduped to one");

        // (NULL, 2) — NULL in same col but different non-null col → distinct.
        let batch_diff = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(arrow_array::Int64Array::from(vec![2i64])),
            ],
        )
        .unwrap();
        let delta2 = rd.compute_delta(&[batch_diff], &schema).unwrap();
        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows2, 1, "(NULL, 2) is distinct from (NULL, 1)");
    }
3410
    #[test]
    fn test_row_dedup_within_candidate_dedup() {
        // Duplicate rows within a single candidate batch should be collapsed
        // to one, not just deduped against previously seen rows.
        let schema = test_schema();
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");

        // Batch with three rows: a:1, a:1, b:2 — "a:1" appears twice.
        let batch = make_batch(&["a", "a", "b"], &[1, 1, 2]);
        let delta = rd.compute_delta(&[batch], &schema).unwrap();
        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows, 2, "within-batch dup should be collapsed: a:1, b:2");
    }
3423
3424    // --- Float rounding tests ---
3425
    #[test]
    fn test_round_float_columns_near_duplicates() {
        // Float64 values differing only in the last few decimal digits must
        // round to the same value, so downstream dedup can treat them as equal.
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("dist", DataType::Float64, true),
        ]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![Some("a"), Some("a")])),
                Arc::new(Float64Array::from(vec![1.0000000000001, 1.0000000000002])),
            ],
        )
        .unwrap();

        let rounded = round_float_columns(&[batch]);
        assert_eq!(rounded.len(), 1);
        let col = rounded[0]
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        // Both should round to same value
        assert_eq!(col.value(0), col.value(1));
    }
3451
3452    // --- DerivedScanRegistry tests ---
3453
    #[test]
    fn test_registry_write_read_round_trip() {
        // Data written through the registry by scan index must be visible when
        // reading the same entry's shared buffer back.
        let schema = test_schema();
        let data = Arc::new(RwLock::new(Vec::new()));
        let mut reg = DerivedScanRegistry::new();
        reg.add(DerivedScanEntry {
            scan_index: 0,
            rule_name: "reachable".into(),
            is_self_ref: true,
            data: Arc::clone(&data),
            schema: Arc::clone(&schema),
        });

        let batch = make_batch(&["x"], &[42]);
        reg.write_data(0, vec![batch.clone()]);

        let entry = reg.get(0).unwrap();
        let guard = entry.data.read();
        assert_eq!(guard.len(), 1);
        assert_eq!(guard[0].num_rows(), 1);
    }
3475
    #[test]
    fn test_registry_entries_for_rule() {
        // `entries_for_rule` must return every entry registered under the
        // given rule name (here "r1" appears twice) and none for unknown names.
        let schema = test_schema();
        let mut reg = DerivedScanRegistry::new();
        reg.add(DerivedScanEntry {
            scan_index: 0,
            rule_name: "r1".into(),
            is_self_ref: true,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });
        reg.add(DerivedScanEntry {
            scan_index: 1,
            rule_name: "r2".into(),
            is_self_ref: false,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });
        reg.add(DerivedScanEntry {
            scan_index: 2,
            rule_name: "r1".into(),
            is_self_ref: false,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });

        assert_eq!(reg.entries_for_rule("r1").len(), 2);
        assert_eq!(reg.entries_for_rule("r2").len(), 1);
        assert_eq!(reg.entries_for_rule("r3").len(), 0);
    }
3506
3507    // --- MonotonicAggState tests ---
3508
    #[test]
    fn test_monotonic_agg_update_and_stability() {
        // `update` with new data must report change (not stable since the last
        // snapshot); an update with no data after a snapshot must be stable.
        use crate::query::df_graph::locy_fold::FoldAggKind;

        let bindings = vec![MonotonicFoldBinding {
            fold_name: "total".into(),
            kind: FoldAggKind::Sum,
            input_col_index: 1,
            input_col_name: None,
        }];
        let mut agg = MonotonicAggState::new(bindings);

        // First update
        let batch = make_batch(&["a"], &[10]);
        agg.snapshot();
        let changed = agg.update(&[0], &[batch], false).unwrap();
        assert!(changed);
        assert!(!agg.is_stable()); // changed since snapshot

        // Snapshot and check stability with no new data
        agg.snapshot();
        let changed = agg.update(&[0], &[], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }
3534
3535    // --- Memory limit test ---
3536
    #[tokio::test]
    async fn test_memory_limit_exceeded() {
        // A merge that pushes derived facts past the configured byte budget
        // must fail with an error mentioning the memory limit.
        let schema = test_schema();
        // Set a tiny limit
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1, None, false);

        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
        let result = state.merge_delta(vec![batch], None).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("memory limit"), "Error was: {}", err);
    }
3549
3550    // --- FixpointStream lifecycle test ---
3551
    #[tokio::test]
    async fn test_fixpoint_stream_emitting() {
        // A stream started directly in the `Emitting` state must yield each
        // pre-computed batch exactly once and then terminate.
        use futures::StreamExt;

        let schema = test_schema();
        let batch1 = make_batch(&["a"], &[1]);
        let batch2 = make_batch(&["b"], &[2]);

        let metrics = ExecutionPlanMetricsSet::new();
        let baseline = BaselineMetrics::new(&metrics, 0);

        let mut stream = FixpointStream {
            state: FixpointStreamState::Emitting(vec![batch1, batch2], 0),
            schema,
            metrics: baseline,
        };

        let stream = Pin::new(&mut stream);
        let batches: Vec<RecordBatch> = stream.filter_map(|r| async { r.ok() }).collect().await;

        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].num_rows(), 1);
        assert_eq!(batches[1].num_rows(), 1);
    }
3576
3577    // ── MonotonicAggState MNOR/MPROD tests ──────────────────────────────
3578
3579    fn make_f64_batch(names: &[&str], values: &[f64]) -> RecordBatch {
3580        let schema = Arc::new(Schema::new(vec![
3581            Field::new("name", DataType::Utf8, true),
3582            Field::new("value", DataType::Float64, true),
3583        ]));
3584        RecordBatch::try_new(
3585            schema,
3586            vec![
3587                Arc::new(StringArray::from(
3588                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
3589                )),
3590                Arc::new(Float64Array::from(values.to_vec())),
3591            ],
3592        )
3593        .unwrap()
3594    }
3595
3596    fn make_nor_binding() -> Vec<MonotonicFoldBinding> {
3597        use crate::query::df_graph::locy_fold::FoldAggKind;
3598        vec![MonotonicFoldBinding {
3599            fold_name: "prob".into(),
3600            kind: FoldAggKind::Nor,
3601            input_col_index: 1,
3602            input_col_name: None,
3603        }]
3604    }
3605
3606    fn make_prod_binding() -> Vec<MonotonicFoldBinding> {
3607        use crate::query::df_graph::locy_fold::FoldAggKind;
3608        vec![MonotonicFoldBinding {
3609            fold_name: "prob".into(),
3610            kind: FoldAggKind::Prod,
3611            input_col_index: 1,
3612            input_col_name: None,
3613        }]
3614    }
3615
3616    fn acc_key(name: &str) -> (Vec<ScalarKey>, String) {
3617        (vec![ScalarKey::Utf8(name.to_string())], "prob".to_string())
3618    }
3619
    #[test]
    fn test_monotonic_nor_first_update() {
        // First NOR contribution: acc = 1 - (1 - 0.3) = 0.3.
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.3]);
        let changed = agg.update(&[0], &[batch], false).unwrap();
        assert!(changed);
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.3).abs() < 1e-10, "expected 0.3, got {}", val);
    }
3629
    #[test]
    fn test_monotonic_nor_two_updates() {
        // Incremental NOR: acc = 1-(1-0.3)(1-0.5) = 0.65
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.5]);
        agg.update(&[0], &[batch2], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.65).abs() < 1e-10, "expected 0.65, got {}", val);
    }
3641
    #[test]
    fn test_monotonic_prod_first_update() {
        // First PROD contribution: accumulator equals the single input 0.6.
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[0.6]);
        let changed = agg.update(&[0], &[batch], false).unwrap();
        assert!(changed);
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.6).abs() < 1e-10, "expected 0.6, got {}", val);
    }
3651
    #[test]
    fn test_monotonic_prod_two_updates() {
        // Incremental PROD: acc = 0.6 * 0.8 = 0.48
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch1 = make_f64_batch(&["a"], &[0.6]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.8]);
        agg.update(&[0], &[batch2], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.48).abs() < 1e-10, "expected 0.48, got {}", val);
    }
3663
    #[test]
    fn test_monotonic_nor_stability() {
        // After a snapshot, an update with no data leaves the NOR accumulator
        // unchanged and the state stable.
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch], false).unwrap();
        agg.snapshot();
        let changed = agg.update(&[0], &[], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }
3674
    #[test]
    fn test_monotonic_prod_stability() {
        // After a snapshot, an update with no data leaves the PROD accumulator
        // unchanged and the state stable.
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[0.6]);
        agg.update(&[0], &[batch], false).unwrap();
        agg.snapshot();
        let changed = agg.update(&[0], &[], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }
3685
    #[test]
    fn test_monotonic_nor_multi_group() {
        // Accumulators are tracked per group key:
        // (a,0.3),(b,0.5) then (a,0.5),(b,0.2) → a=0.65, b=0.6
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a", "b"], &[0.3, 0.5]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a", "b"], &[0.5, 0.2]);
        agg.update(&[0], &[batch2], false).unwrap();

        let val_a = agg.get_accumulator(&acc_key("a")).unwrap();
        let val_b = agg.get_accumulator(&acc_key("b")).unwrap();
        assert!(
            (val_a - 0.65).abs() < 1e-10,
            "expected a=0.65, got {}",
            val_a
        );
        assert!((val_b - 0.6).abs() < 1e-10, "expected b=0.6, got {}", val_b);
    }
3704
    #[test]
    fn test_monotonic_prod_zero_absorbing() {
        // Zero absorbs: once 0.0, all further updates stay 0.0
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch1 = make_f64_batch(&["a"], &[0.5]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.0]);
        agg.update(&[0], &[batch2], false).unwrap();

        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.0).abs() < 1e-10, "expected 0.0, got {}", val);

        // Further updates don't change the absorbing zero, so the state must
        // also be reported as stable (no spurious extra iterations).
        agg.snapshot();
        let batch3 = make_f64_batch(&["a"], &[0.5]);
        let changed = agg.update(&[0], &[batch3], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }
3724
    #[test]
    fn test_monotonic_nor_clamping() {
        // Out-of-range input 1.5 is clamped to 1.0 (non-strict mode), so the
        // accumulator becomes 1 - (1 - 1.0) = 1.0.
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[1.5]);
        agg.update(&[0], &[batch], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
    }
3734
    #[test]
    fn test_monotonic_nor_absorbing() {
        // p=1.0 absorbs: 0.3 then 1.0 → 1.0
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[1.0]);
        agg.update(&[0], &[batch2], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
    }
3746
3747    // ── MonotonicAggState strict mode tests (Phase 5) ─────────────────────
3748
    #[test]
    fn test_monotonic_agg_strict_nor_rejects() {
        // With strict=true, an out-of-[0,1] NOR input (1.5) must error instead
        // of being clamped, and the message must name the setting.
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[1.5]);
        let result = agg.update(&[0], &[batch], true);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("strict_probability_domain"),
            "Expected strict error, got: {}",
            err
        );
    }
3762
    #[test]
    fn test_monotonic_agg_strict_prod_rejects() {
        // With strict=true, an out-of-[0,1] PROD input (2.0) must error and
        // the message must name the setting.
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[2.0]);
        let result = agg.update(&[0], &[batch], true);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("strict_probability_domain"),
            "Expected strict error, got: {}",
            err
        );
    }
3776
    #[test]
    fn test_monotonic_agg_strict_accepts_valid() {
        // strict=true must still accept inputs already inside [0, 1].
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.5]);
        let result = agg.update(&[0], &[batch], true);
        assert!(result.is_ok());
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.5).abs() < 1e-10, "expected 0.5, got {}", val);
    }
3786
3787    // ── Complement function unit tests (Phase 4) ──────────────────────────
3788
3789    fn make_vid_prob_batch(vids: &[u64], probs: &[f64]) -> RecordBatch {
3790        use arrow_array::UInt64Array;
3791        let schema = Arc::new(Schema::new(vec![
3792            Field::new("vid", DataType::UInt64, true),
3793            Field::new("prob", DataType::Float64, true),
3794        ]));
3795        RecordBatch::try_new(
3796            schema,
3797            vec![
3798                Arc::new(UInt64Array::from(vids.to_vec())),
3799                Arc::new(Float64Array::from(probs.to_vec())),
3800            ],
3801        )
3802        .unwrap()
3803    }
3804
    #[test]
    fn test_prob_complement_basic() {
        // neg has VID=1 with prob=0.7 → complement=0.3; VID=2 absent → complement=1.0
        let body = make_vid_prob_batch(&[1, 2], &[0.9, 0.8]);
        let neg = make_vid_prob_batch(&[1], &[0.7]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result = apply_prob_complement_composite(
            vec![body],
            &[neg],
            &join_cols,
            "prob",
            "__complement_0",
        )
        .unwrap();
        assert_eq!(result.len(), 1);
        let batch = &result[0];
        let complement = batch
            .column_by_name("__complement_0")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        // VID=1: complement = 1 - 0.7 = 0.3
        assert!(
            (complement.value(0) - 0.3).abs() < 1e-10,
            "expected 0.3, got {}",
            complement.value(0)
        );
        // VID=2: absent from neg → complement = 1.0
        assert!(
            (complement.value(1) - 1.0).abs() < 1e-10,
            "expected 1.0, got {}",
            complement.value(1)
        );
    }
3840
    #[test]
    fn test_prob_complement_noisy_or_duplicates() {
        // neg has VID=1 twice with prob=0.3 and prob=0.5
        // Combined via noisy-OR: 1-(1-0.3)(1-0.5) = 0.65
        // Complement = 1 - 0.65 = 0.35
        let body = make_vid_prob_batch(&[1], &[0.9]);
        let neg = make_vid_prob_batch(&[1, 1], &[0.3, 0.5]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result = apply_prob_complement_composite(
            vec![body],
            &[neg],
            &join_cols,
            "prob",
            "__complement_0",
        )
        .unwrap();
        let batch = &result[0];
        let complement = batch
            .column_by_name("__complement_0")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!(
            (complement.value(0) - 0.35).abs() < 1e-10,
            "expected 0.35, got {}",
            complement.value(0)
        );
    }
3870
    #[test]
    fn test_prob_complement_empty_neg() {
        // Empty neg_facts → body passes through with complement=1.0 on every row
        let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result =
            apply_prob_complement_composite(vec![body], &[], &join_cols, "prob", "__complement_0")
                .unwrap();
        let batch = &result[0];
        let complement = batch
            .column_by_name("__complement_0")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        for i in 0..2 {
            assert!(
                (complement.value(i) - 1.0).abs() < 1e-10,
                "row {}: expected 1.0, got {}",
                i,
                complement.value(i)
            );
        }
    }
3895
3896    #[test]
3897    fn test_anti_join_basic() {
3898        // body [1,2,3], neg [2] → result [1,3]
3899        use arrow_array::UInt64Array;
3900        let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
3901        let neg = make_vid_prob_batch(&[2], &[0.0]);
3902        let join_cols = vec![("vid".to_string(), "vid".to_string())];
3903        let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
3904        assert_eq!(result.len(), 1);
3905        let batch = &result[0];
3906        assert_eq!(batch.num_rows(), 2);
3907        let vids = batch
3908            .column_by_name("vid")
3909            .unwrap()
3910            .as_any()
3911            .downcast_ref::<UInt64Array>()
3912            .unwrap();
3913        assert_eq!(vids.value(0), 1);
3914        assert_eq!(vids.value(1), 3);
3915    }
3916
3917    #[test]
3918    fn test_anti_join_empty_neg() {
3919        // Empty neg → all rows kept
3920        let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
3921        let join_cols = vec![("vid".to_string(), "vid".to_string())];
3922        let result = apply_anti_join_composite(vec![body], &[], &join_cols).unwrap();
3923        assert_eq!(result.len(), 1);
3924        assert_eq!(result[0].num_rows(), 3);
3925    }
3926
3927    #[test]
3928    fn test_anti_join_all_excluded() {
3929        // neg covers all body rows → empty result
3930        let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
3931        let neg = make_vid_prob_batch(&[1, 2], &[0.0, 0.0]);
3932        let join_cols = vec![("vid".to_string(), "vid".to_string())];
3933        let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
3934        let total: usize = result.iter().map(|b| b.num_rows()).sum();
3935        assert_eq!(total, 0);
3936    }
3937
3938    #[test]
3939    fn test_multiply_prob_single_complement() {
3940        // prob=0.8, complement=0.5 → output prob=0.4; complement col removed
3941        let body = make_vid_prob_batch(&[1], &[0.8]);
3942        // Add a complement column
3943        let complement_arr = Float64Array::from(vec![0.5]);
3944        let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
3945        cols.push(Arc::new(complement_arr));
3946        let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
3947        fields.push(Arc::new(Field::new(
3948            "__complement_0",
3949            DataType::Float64,
3950            true,
3951        )));
3952        let schema = Arc::new(Schema::new(fields));
3953        let batch = RecordBatch::try_new(schema, cols).unwrap();
3954
3955        let result =
3956            multiply_prob_factors(vec![batch], Some("prob"), &["__complement_0".to_string()])
3957                .unwrap();
3958        assert_eq!(result.len(), 1);
3959        let out = &result[0];
3960        // Complement column should be removed
3961        assert!(out.column_by_name("__complement_0").is_none());
3962        let prob = out
3963            .column_by_name("prob")
3964            .unwrap()
3965            .as_any()
3966            .downcast_ref::<Float64Array>()
3967            .unwrap();
3968        assert!(
3969            (prob.value(0) - 0.4).abs() < 1e-10,
3970            "expected 0.4, got {}",
3971            prob.value(0)
3972        );
3973    }
3974
3975    #[test]
3976    fn test_multiply_prob_multiple_complements() {
3977        // prob=0.8, c1=0.5, c2=0.6 → 0.8×0.5×0.6=0.24
3978        let body = make_vid_prob_batch(&[1], &[0.8]);
3979        let c1 = Float64Array::from(vec![0.5]);
3980        let c2 = Float64Array::from(vec![0.6]);
3981        let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
3982        cols.push(Arc::new(c1));
3983        cols.push(Arc::new(c2));
3984        let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
3985        fields.push(Arc::new(Field::new("__c1", DataType::Float64, true)));
3986        fields.push(Arc::new(Field::new("__c2", DataType::Float64, true)));
3987        let schema = Arc::new(Schema::new(fields));
3988        let batch = RecordBatch::try_new(schema, cols).unwrap();
3989
3990        let result = multiply_prob_factors(
3991            vec![batch],
3992            Some("prob"),
3993            &["__c1".to_string(), "__c2".to_string()],
3994        )
3995        .unwrap();
3996        let out = &result[0];
3997        assert!(out.column_by_name("__c1").is_none());
3998        assert!(out.column_by_name("__c2").is_none());
3999        let prob = out
4000            .column_by_name("prob")
4001            .unwrap()
4002            .as_any()
4003            .downcast_ref::<Float64Array>()
4004            .unwrap();
4005        assert!(
4006            (prob.value(0) - 0.24).abs() < 1e-10,
4007            "expected 0.24, got {}",
4008            prob.value(0)
4009        );
4010    }
4011
4012    #[test]
4013    fn test_multiply_prob_no_prob_column() {
4014        // No prob column → combined complements become the output
4015        use arrow_array::UInt64Array;
4016        let schema = Arc::new(Schema::new(vec![
4017            Field::new("vid", DataType::UInt64, true),
4018            Field::new("__c1", DataType::Float64, true),
4019        ]));
4020        let batch = RecordBatch::try_new(
4021            schema,
4022            vec![
4023                Arc::new(UInt64Array::from(vec![1u64])),
4024                Arc::new(Float64Array::from(vec![0.7])),
4025            ],
4026        )
4027        .unwrap();
4028
4029        let result = multiply_prob_factors(vec![batch], None, &["__c1".to_string()]).unwrap();
4030        let out = &result[0];
4031        // __c1 should be removed since it's a complement column
4032        assert!(out.column_by_name("__c1").is_none());
4033        // Only vid column remains
4034        assert_eq!(out.num_columns(), 1);
4035    }
4036}