// uni_query/query/df_graph/locy_fixpoint.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Fixpoint iteration operator for recursive Locy strata.
5//!
6//! `FixpointExec` drives semi-naive evaluation: it repeatedly evaluates the rules
7//! in a recursive stratum, feeding back deltas until no new facts are produced.
8
9use crate::query::df_graph::GraphExecutionContext;
10use crate::query::df_graph::common::{
11    ScalarKey, collect_all_partitions, compute_plan_properties, execute_subplan, extract_scalar_key,
12};
13use crate::query::df_graph::locy_best_by::{BestByExec, SortCriterion};
14use crate::query::df_graph::locy_errors::LocyRuntimeError;
15use crate::query::df_graph::locy_explain::{DerivationEntry, DerivationTracker};
16use crate::query::df_graph::locy_fold::{FoldBinding, FoldExec};
17use crate::query::df_graph::locy_priority::PriorityExec;
18use crate::query::planner::LogicalPlan;
19use arrow_array::RecordBatch;
20use arrow_row::{RowConverter, SortField};
21use arrow_schema::SchemaRef;
22use datafusion::common::JoinType;
23use datafusion::common::Result as DFResult;
24use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
25use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
26use datafusion::physical_plan::memory::MemoryStream;
27use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
28use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
29use futures::Stream;
30use parking_lot::RwLock;
31use std::any::Any;
32use std::collections::{HashMap, HashSet};
33use std::fmt;
34use std::pin::Pin;
35use std::sync::{Arc, RwLock as StdRwLock};
36use std::task::{Context, Poll};
37use std::time::{Duration, Instant};
38use uni_common::Value;
39use uni_common::core::schema::Schema as UniSchema;
40use uni_store::storage::manager::StorageManager;
41
42// ---------------------------------------------------------------------------
43// DerivedScanRegistry — injection point for IS-ref data into subplans
44// ---------------------------------------------------------------------------
45
/// A single entry in the derived scan registry.
///
/// Each entry corresponds to one `LocyDerivedScan` node in the logical plan tree.
/// The `data` handle is shared with the logical plan node so that writing data here
/// makes it visible when the subplan is re-planned and executed.
#[derive(Debug)]
pub struct DerivedScanEntry {
    /// Index matching the `scan_index` in `LocyDerivedScan`; used as the
    /// registry lookup key.
    pub scan_index: usize,
    /// Name of the rule this scan reads from.
    pub rule_name: String,
    /// Whether this is a self-referential scan (rule references itself).
    pub is_self_ref: bool,
    /// Shared data handle — write batches here to inject into subplans.
    /// Guarded by a `parking_lot::RwLock` so the fixpoint writer and
    /// re-planned subplan readers can share it.
    pub data: Arc<RwLock<Vec<RecordBatch>>>,
    /// Schema of the derived relation.
    pub schema: SchemaRef,
}
64
/// Registry of derived scan handles for fixpoint iteration.
///
/// During fixpoint, each clause body may reference derived relations via
/// `LocyDerivedScan` nodes. The registry maps scan indices to shared data
/// handles so the fixpoint loop can inject delta/full facts before each
/// iteration.
#[derive(Debug, Default)]
pub struct DerivedScanRegistry {
    /// Entries are looked up by linear search on `scan_index` (small N expected).
    entries: Vec<DerivedScanEntry>,
}
75
76impl DerivedScanRegistry {
77    /// Create a new empty registry.
78    pub fn new() -> Self {
79        Self::default()
80    }
81
82    /// Add an entry to the registry.
83    pub fn add(&mut self, entry: DerivedScanEntry) {
84        self.entries.push(entry);
85    }
86
87    /// Get an entry by scan index.
88    pub fn get(&self, scan_index: usize) -> Option<&DerivedScanEntry> {
89        self.entries.iter().find(|e| e.scan_index == scan_index)
90    }
91
92    /// Write data into a scan entry's shared handle.
93    pub fn write_data(&self, scan_index: usize, batches: Vec<RecordBatch>) {
94        if let Some(entry) = self.get(scan_index) {
95            let mut guard = entry.data.write();
96            *guard = batches;
97        }
98    }
99
100    /// Get all entries for a given rule name.
101    pub fn entries_for_rule(&self, rule_name: &str) -> Vec<&DerivedScanEntry> {
102        self.entries
103            .iter()
104            .filter(|e| e.rule_name == rule_name)
105            .collect()
106    }
107}
108
109// ---------------------------------------------------------------------------
110// MonotonicAggState — tracking monotonic aggregates across iterations
111// ---------------------------------------------------------------------------
112
/// Monotonic aggregate binding: maps a fold name to its aggregate kind and column.
#[derive(Debug, Clone)]
pub struct MonotonicFoldBinding {
    /// Output name of the FOLD aggregate being tracked.
    pub fold_name: String,
    /// Aggregate kind (sum/count/max/min/…) driving the accumulator update.
    pub kind: crate::query::df_graph::locy_fold::FoldAggKind,
    /// Column index into the rule's delta batches holding the values to aggregate.
    pub input_col_index: usize,
}
120
/// Tracks monotonic aggregate accumulators across fixpoint iterations.
///
/// After each iteration, accumulators are updated and compared to their previous
/// snapshot. The fixpoint has converged (w.r.t. aggregates) when all accumulators
/// are stable (no change between iterations).
///
/// Values are kept as `f64` regardless of the source column type
/// (see `extract_f64`).
#[derive(Debug)]
pub struct MonotonicAggState {
    /// Current accumulator values keyed by (group_key, fold_name).
    accumulators: HashMap<(Vec<ScalarKey>, String), f64>,
    /// Snapshot from the previous iteration for stability check.
    prev_snapshot: HashMap<(Vec<ScalarKey>, String), f64>,
    /// Bindings describing which aggregates to track.
    bindings: Vec<MonotonicFoldBinding>,
}
135
136impl MonotonicAggState {
137    /// Create a new monotonic aggregate state.
138    pub fn new(bindings: Vec<MonotonicFoldBinding>) -> Self {
139        Self {
140            accumulators: HashMap::new(),
141            prev_snapshot: HashMap::new(),
142            bindings,
143        }
144    }
145
146    /// Update accumulators with new delta batches. Returns true if any value changed.
147    pub fn update(&mut self, key_indices: &[usize], delta_batches: &[RecordBatch]) -> bool {
148        use crate::query::df_graph::locy_fold::FoldAggKind;
149
150        let mut changed = false;
151        for batch in delta_batches {
152            for row_idx in 0..batch.num_rows() {
153                let group_key = extract_scalar_key(batch, key_indices, row_idx);
154                for binding in &self.bindings {
155                    let col = batch.column(binding.input_col_index);
156                    let val = extract_f64(col.as_ref(), row_idx);
157                    if let Some(val) = val {
158                        let map_key = (group_key.clone(), binding.fold_name.clone());
159                        let entry =
160                            self.accumulators
161                                .entry(map_key)
162                                .or_insert(match binding.kind {
163                                    FoldAggKind::Sum | FoldAggKind::Count | FoldAggKind::Avg => 0.0,
164                                    FoldAggKind::Max => f64::NEG_INFINITY,
165                                    FoldAggKind::Min => f64::INFINITY,
166                                    FoldAggKind::Collect => 0.0,
167                                });
168                        let old = *entry;
169                        match binding.kind {
170                            FoldAggKind::Sum | FoldAggKind::Count => *entry += val,
171                            FoldAggKind::Max => {
172                                if val > *entry {
173                                    *entry = val;
174                                }
175                            }
176                            FoldAggKind::Min => {
177                                if val < *entry {
178                                    *entry = val;
179                                }
180                            }
181                            _ => {}
182                        }
183                        if (*entry - old).abs() > f64::EPSILON {
184                            changed = true;
185                        }
186                    }
187                }
188            }
189        }
190        changed
191    }
192
193    /// Take a snapshot of current accumulators for stability comparison.
194    pub fn snapshot(&mut self) {
195        self.prev_snapshot = self.accumulators.clone();
196    }
197
198    /// Check if accumulators are stable (no change since last snapshot).
199    pub fn is_stable(&self) -> bool {
200        if self.accumulators.len() != self.prev_snapshot.len() {
201            return false;
202        }
203        for (key, val) in &self.accumulators {
204            match self.prev_snapshot.get(key) {
205                Some(prev) if (*val - *prev).abs() <= f64::EPSILON => {}
206                _ => return false,
207            }
208        }
209        true
210    }
211}
212
213/// Extract f64 value from an Arrow column at a given row index.
214fn extract_f64(col: &dyn arrow_array::Array, row_idx: usize) -> Option<f64> {
215    if col.is_null(row_idx) {
216        return None;
217    }
218    if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Float64Array>() {
219        Some(arr.value(row_idx))
220    } else {
221        col.as_any()
222            .downcast_ref::<arrow_array::Int64Array>()
223            .map(|arr| arr.value(row_idx) as f64)
224    }
225}
226
227// ---------------------------------------------------------------------------
228// RowDedupState — Arrow RowConverter-based persistent dedup set
229// ---------------------------------------------------------------------------
230
/// Arrow-native row deduplication using [`RowConverter`].
///
/// Unlike the legacy `HashSet<Vec<ScalarKey>>` approach, this struct maintains a
/// persistent `seen` set across iterations so per-iteration cost is O(M) where M
/// is the number of candidate rows — the full facts table is never re-scanned.
struct RowDedupState {
    /// Encodes batch rows into comparable byte strings.
    converter: RowConverter,
    /// Byte encodings of every row accepted so far (across all iterations).
    seen: HashSet<Box<[u8]>>,
}
240
241impl RowDedupState {
242    /// Try to build a `RowDedupState` for the given schema.
243    ///
244    /// Returns `None` if any column type is not supported by `RowConverter`
245    /// (triggers legacy fallback).
246    fn try_new(schema: &SchemaRef) -> Option<Self> {
247        let fields: Vec<SortField> = schema
248            .fields()
249            .iter()
250            .map(|f| SortField::new(f.data_type().clone()))
251            .collect();
252        match RowConverter::new(fields) {
253            Ok(converter) => Some(Self {
254                converter,
255                seen: HashSet::new(),
256            }),
257            Err(e) => {
258                tracing::warn!(
259                    "RowDedupState: RowConverter unsupported for schema, falling back to legacy dedup: {}",
260                    e
261                );
262                None
263            }
264        }
265    }
266
267    /// Filter `candidates` to only rows not yet seen, updating the persistent set.
268    ///
269    /// Both cross-iteration dedup (rows already accepted in prior iterations) and
270    /// within-batch dedup (duplicate rows in a single candidate batch) are handled
271    /// in a single pass.
272    fn compute_delta(
273        &mut self,
274        candidates: &[RecordBatch],
275        schema: &SchemaRef,
276    ) -> DFResult<Vec<RecordBatch>> {
277        let mut delta_batches = Vec::new();
278        for batch in candidates {
279            if batch.num_rows() == 0 {
280                continue;
281            }
282
283            // Vectorized encoding of all rows in this batch.
284            let arrays: Vec<_> = batch.columns().to_vec();
285            let rows = self
286                .converter
287                .convert_columns(&arrays)
288                .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
289
290            // One pass: check+insert into persistent seen set.
291            let mut keep = Vec::with_capacity(batch.num_rows());
292            for row_idx in 0..batch.num_rows() {
293                let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
294                keep.push(self.seen.insert(row_bytes));
295            }
296
297            let keep_mask = arrow_array::BooleanArray::from(keep);
298            let new_cols = batch
299                .columns()
300                .iter()
301                .map(|col| {
302                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
303                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
304                    })
305                })
306                .collect::<DFResult<Vec<_>>>()?;
307
308            if new_cols.first().is_some_and(|c| !c.is_empty()) {
309                let filtered = RecordBatch::try_new(Arc::clone(schema), new_cols).map_err(|e| {
310                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
311                })?;
312                delta_batches.push(filtered);
313            }
314        }
315        Ok(delta_batches)
316    }
317}
318
319// ---------------------------------------------------------------------------
320// FixpointState — per-rule delta tracking during fixpoint iteration
321// ---------------------------------------------------------------------------
322
/// Per-rule state for fixpoint iteration.
///
/// Tracks accumulated facts and the delta (new facts from the latest iteration).
/// Deduplication uses Arrow [`RowConverter`] with a persistent seen set (O(M) per
/// iteration) when supported, with a legacy `HashSet<Vec<ScalarKey>>` fallback.
pub struct FixpointState {
    /// Name of the rule this state belongs to (used in error messages).
    rule_name: String,
    /// All facts accumulated so far across iterations.
    facts: Vec<RecordBatch>,
    /// New facts produced by the most recent `merge_delta` call.
    delta: Vec<RecordBatch>,
    /// Schema of this rule's derived relation.
    schema: SchemaRef,
    /// Indices of KEY columns, used for monotonic aggregate grouping.
    key_column_indices: Vec<usize>,
    /// All column indices for full-row dedup (legacy path only).
    all_column_indices: Vec<usize>,
    /// Running total of facts bytes for memory limit tracking.
    facts_bytes: usize,
    /// Maximum bytes allowed for this derived relation.
    max_derived_bytes: usize,
    /// Optional monotonic aggregate tracking.
    monotonic_agg: Option<MonotonicAggState>,
    /// Arrow RowConverter-based dedup state; `None` triggers legacy fallback.
    row_dedup: Option<RowDedupState>,
}
345
346impl FixpointState {
347    /// Create a new fixpoint state for a rule.
348    pub fn new(
349        rule_name: String,
350        schema: SchemaRef,
351        key_column_indices: Vec<usize>,
352        max_derived_bytes: usize,
353        monotonic_agg: Option<MonotonicAggState>,
354    ) -> Self {
355        let num_cols = schema.fields().len();
356        let row_dedup = RowDedupState::try_new(&schema);
357        Self {
358            rule_name,
359            facts: Vec::new(),
360            delta: Vec::new(),
361            schema,
362            key_column_indices,
363            all_column_indices: (0..num_cols).collect(),
364            facts_bytes: 0,
365            max_derived_bytes,
366            monotonic_agg,
367            row_dedup,
368        }
369    }
370
371    /// Merge candidate rows into facts, computing delta (truly new rows).
372    ///
373    /// Returns `true` if any new facts were added.
374    pub async fn merge_delta(
375        &mut self,
376        candidates: Vec<RecordBatch>,
377        task_ctx: Option<Arc<TaskContext>>,
378    ) -> DFResult<bool> {
379        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
380            self.delta.clear();
381            return Ok(false);
382        }
383
384        // Round floats for stable dedup
385        let candidates = round_float_columns(&candidates);
386
387        // Compute delta: rows in candidates not already in facts
388        let delta = self.compute_delta(&candidates, task_ctx.as_ref()).await?;
389
390        if delta.is_empty() || delta.iter().all(|b| b.num_rows() == 0) {
391            self.delta.clear();
392            // Update monotonic aggs even with empty delta (for stability check)
393            if let Some(ref mut agg) = self.monotonic_agg {
394                agg.snapshot();
395            }
396            return Ok(false);
397        }
398
399        // Check memory limit
400        let delta_bytes: usize = delta.iter().map(batch_byte_size).sum();
401        if self.facts_bytes + delta_bytes > self.max_derived_bytes {
402            return Err(datafusion::error::DataFusionError::Execution(
403                LocyRuntimeError::MemoryLimitExceeded {
404                    rule: self.rule_name.clone(),
405                    bytes: self.facts_bytes + delta_bytes,
406                    limit: self.max_derived_bytes,
407                }
408                .to_string(),
409            ));
410        }
411
412        // Update monotonic aggs
413        if let Some(ref mut agg) = self.monotonic_agg {
414            agg.snapshot();
415            agg.update(&self.key_column_indices, &delta);
416        }
417
418        // Append delta to facts
419        self.facts_bytes += delta_bytes;
420        self.facts.extend(delta.iter().cloned());
421        self.delta = delta;
422
423        Ok(true)
424    }
425
426    /// Dispatch to vectorized LeftAntiJoin, Arrow RowConverter dedup, or legacy ScalarKey dedup.
427    ///
428    /// Priority order:
429    /// 1. `arrow_left_anti_dedup` when `total_existing >= DEDUP_ANTI_JOIN_THRESHOLD` and task_ctx available.
430    /// 2. `RowDedupState` (persistent HashSet, O(M) per iteration) when schema is supported.
431    /// 3. `compute_delta_legacy` (rebuilds from facts, fallback for unsupported column types).
432    async fn compute_delta(
433        &mut self,
434        candidates: &[RecordBatch],
435        task_ctx: Option<&Arc<TaskContext>>,
436    ) -> DFResult<Vec<RecordBatch>> {
437        let total_existing: usize = self.facts.iter().map(|b| b.num_rows()).sum();
438        if total_existing >= DEDUP_ANTI_JOIN_THRESHOLD
439            && let Some(ctx) = task_ctx
440        {
441            return arrow_left_anti_dedup(candidates.to_vec(), &self.facts, &self.schema, ctx)
442                .await;
443        }
444        if let Some(ref mut rd) = self.row_dedup {
445            rd.compute_delta(candidates, &self.schema)
446        } else {
447            self.compute_delta_legacy(candidates)
448        }
449    }
450
451    /// Legacy dedup: rebuild a `HashSet<Vec<ScalarKey>>` from all facts each call.
452    ///
453    /// Used as fallback when `RowConverter` does not support the schema's column types.
454    fn compute_delta_legacy(&self, candidates: &[RecordBatch]) -> DFResult<Vec<RecordBatch>> {
455        // Build set of existing fact row keys (ALL columns)
456        let mut existing: HashSet<Vec<ScalarKey>> = HashSet::new();
457        for batch in &self.facts {
458            for row_idx in 0..batch.num_rows() {
459                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
460                existing.insert(key);
461            }
462        }
463
464        let mut delta_batches = Vec::new();
465        for batch in candidates {
466            if batch.num_rows() == 0 {
467                continue;
468            }
469            // Filter to only new rows
470            let mut keep = Vec::with_capacity(batch.num_rows());
471            for row_idx in 0..batch.num_rows() {
472                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
473                keep.push(!existing.contains(&key));
474            }
475
476            // Also dedup within the candidate batch itself
477            for (row_idx, kept) in keep.iter_mut().enumerate() {
478                if *kept {
479                    let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
480                    if !existing.insert(key) {
481                        *kept = false;
482                    }
483                }
484            }
485
486            let keep_mask = arrow_array::BooleanArray::from(keep);
487            let new_rows = batch
488                .columns()
489                .iter()
490                .map(|col| {
491                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
492                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
493                    })
494                })
495                .collect::<DFResult<Vec<_>>>()?;
496
497            if new_rows.first().is_some_and(|c| !c.is_empty()) {
498                let filtered =
499                    RecordBatch::try_new(Arc::clone(&self.schema), new_rows).map_err(|e| {
500                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
501                    })?;
502                delta_batches.push(filtered);
503            }
504        }
505
506        Ok(delta_batches)
507    }
508
509    /// Check if this rule has converged (no new facts and aggs stable).
510    pub fn is_converged(&self) -> bool {
511        let delta_empty = self.delta.is_empty() || self.delta.iter().all(|b| b.num_rows() == 0);
512        let agg_stable = self.monotonic_agg.as_ref().is_none_or(|a| a.is_stable());
513        delta_empty && agg_stable
514    }
515
516    /// Get all accumulated facts.
517    pub fn all_facts(&self) -> &[RecordBatch] {
518        &self.facts
519    }
520
521    /// Get the delta from the latest iteration.
522    pub fn all_delta(&self) -> &[RecordBatch] {
523        &self.delta
524    }
525
526    /// Consume self and return facts.
527    pub fn into_facts(self) -> Vec<RecordBatch> {
528        self.facts
529    }
530}
531
532/// Estimate byte size of a RecordBatch.
533fn batch_byte_size(batch: &RecordBatch) -> usize {
534    batch
535        .columns()
536        .iter()
537        .map(|col| col.get_buffer_memory_size())
538        .sum()
539}
540
541// ---------------------------------------------------------------------------
542// Float rounding for stable dedup
543// ---------------------------------------------------------------------------
544
545/// Round all Float64 columns to 12 decimal places for stable dedup.
546fn round_float_columns(batches: &[RecordBatch]) -> Vec<RecordBatch> {
547    batches
548        .iter()
549        .map(|batch| {
550            let schema = batch.schema();
551            let has_float = schema
552                .fields()
553                .iter()
554                .any(|f| *f.data_type() == arrow_schema::DataType::Float64);
555            if !has_float {
556                return batch.clone();
557            }
558
559            let columns: Vec<arrow_array::ArrayRef> = batch
560                .columns()
561                .iter()
562                .enumerate()
563                .map(|(i, col)| {
564                    if *schema.field(i).data_type() == arrow_schema::DataType::Float64 {
565                        let arr = col
566                            .as_any()
567                            .downcast_ref::<arrow_array::Float64Array>()
568                            .unwrap();
569                        let rounded: arrow_array::Float64Array = arr
570                            .iter()
571                            .map(|v| v.map(|f| (f * 1e12).round() / 1e12))
572                            .collect();
573                        Arc::new(rounded) as arrow_array::ArrayRef
574                    } else {
575                        Arc::clone(col)
576                    }
577                })
578                .collect();
579
580            RecordBatch::try_new(schema, columns).unwrap_or_else(|_| batch.clone())
581        })
582        .collect()
583}
584
585// ---------------------------------------------------------------------------
586// LeftAntiJoin delta deduplication
587// ---------------------------------------------------------------------------
588
/// Row threshold above which the vectorized Arrow LeftAntiJoin dedup path is used.
///
/// Below this threshold the persistent `RowDedupState` HashSet is O(M) and
/// avoids rebuilding the existing-row set; above it DataFusion's vectorized
/// HashJoinExec is more cache-efficient.
///
/// NOTE(review): 300 appears to be an empirically chosen crossover point —
/// re-verify with a benchmark if typical workloads change.
const DEDUP_ANTI_JOIN_THRESHOLD: usize = 300;
595
596/// Deduplicate `candidates` against `existing` using DataFusion's HashJoinExec.
597///
598/// Returns rows in `candidates` that do not appear in `existing` (LeftAnti semantics).
599/// `null_equals_null = true` so NULLs are treated as equal for dedup purposes.
600async fn arrow_left_anti_dedup(
601    candidates: Vec<RecordBatch>,
602    existing: &[RecordBatch],
603    schema: &SchemaRef,
604    task_ctx: &Arc<TaskContext>,
605) -> DFResult<Vec<RecordBatch>> {
606    if existing.is_empty() || existing.iter().all(|b| b.num_rows() == 0) {
607        return Ok(candidates);
608    }
609
610    let left: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(candidates, Arc::clone(schema)));
611    let right: Arc<dyn ExecutionPlan> =
612        Arc::new(InMemoryExec::new(existing.to_vec(), Arc::clone(schema)));
613
614    let on: Vec<(
615        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
616        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
617    )> = schema
618        .fields()
619        .iter()
620        .enumerate()
621        .map(|(i, field)| {
622            let l: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
623                datafusion::physical_plan::expressions::Column::new(field.name(), i),
624            );
625            let r: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
626                datafusion::physical_plan::expressions::Column::new(field.name(), i),
627            );
628            (l, r)
629        })
630        .collect();
631
632    if on.is_empty() {
633        return Ok(vec![]);
634    }
635
636    let join = HashJoinExec::try_new(
637        left,
638        right,
639        on,
640        None,
641        &JoinType::LeftAnti,
642        None,
643        PartitionMode::CollectLeft,
644        datafusion::common::NullEquality::NullEqualsNull,
645    )?;
646
647    let join_arc: Arc<dyn ExecutionPlan> = Arc::new(join);
648    collect_all_partitions(&join_arc, task_ctx.clone()).await
649}
650
651// ---------------------------------------------------------------------------
652// Plan types for fixpoint rules
653// ---------------------------------------------------------------------------
654
/// IS-ref binding: a reference from a clause body to a derived relation.
#[derive(Debug, Clone)]
pub struct IsRefBinding {
    /// Index into the DerivedScanRegistry.
    pub derived_scan_index: usize,
    /// Name of the rule being referenced.
    pub rule_name: String,
    /// Whether this is a self-reference (rule references itself).
    pub is_self_ref: bool,
    /// Whether this is a negated reference (NOT IS); negated refs are applied
    /// as anti-joins against the referenced rule's facts during the fixpoint loop.
    pub negated: bool,
    /// For negated IS-refs: `(left_body_col, right_derived_col)` pairs for anti-join filtering.
    ///
    /// `left_body_col` is the VID column in the clause body (e.g., `"n._vid"`);
    /// `right_derived_col` is the corresponding KEY column in the negated rule's facts (e.g., `"n"`).
    /// Empty for non-negated IS-refs.
    pub anti_join_cols: Vec<(String, String)>,
}
673
/// A single clause (body) within a fixpoint rule.
///
/// Each clause is evaluated per iteration; candidate rows from all of a
/// rule's clauses are unioned before delta computation.
#[derive(Debug)]
pub struct FixpointClausePlan {
    /// The logical plan for the clause body (re-planned each iteration).
    pub body_logical: LogicalPlan,
    /// IS-ref bindings used by this clause.
    pub is_ref_bindings: Vec<IsRefBinding>,
    /// Priority value for this clause (if PRIORITY semantics apply).
    pub priority: Option<i64>,
}
684
/// Physical plan for a single rule in a fixpoint stratum.
#[derive(Debug)]
pub struct FixpointRulePlan {
    /// Rule name.
    pub name: String,
    /// Clause bodies (each evaluates to candidate rows; results are unioned).
    pub clauses: Vec<FixpointClausePlan>,
    /// Output schema for this rule's derived relation.
    pub yield_schema: SchemaRef,
    /// Indices of KEY columns within yield_schema.
    pub key_column_indices: Vec<usize>,
    /// Priority value (if PRIORITY semantics apply).
    pub priority: Option<i64>,
    /// Whether this rule has FOLD semantics.
    pub has_fold: bool,
    /// FOLD bindings for post-fixpoint aggregation; also drive the
    /// monotonic-aggregate convergence check during iteration.
    pub fold_bindings: Vec<FoldBinding>,
    /// Whether this rule has BEST BY semantics.
    pub has_best_by: bool,
    /// BEST BY sort criteria for post-fixpoint selection.
    pub best_by_criteria: Vec<SortCriterion>,
    /// Whether this rule has PRIORITY semantics.
    pub has_priority: bool,
    /// Whether BEST BY should apply a deterministic secondary sort for
    /// tie-breaking. When false, tied rows are selected non-deterministically
    /// (faster but not repeatable across runs).
    pub deterministic: bool,
}
713
714// ---------------------------------------------------------------------------
715// run_fixpoint_loop — the core semi-naive iteration algorithm
716// ---------------------------------------------------------------------------
717
718/// Run the semi-naive fixpoint iteration loop.
719///
720/// Evaluates all rules in a stratum repeatedly, feeding deltas back through
721/// derived scan handles until convergence or limits are reached.
722#[allow(clippy::too_many_arguments)]
723async fn run_fixpoint_loop(
724    rules: Vec<FixpointRulePlan>,
725    max_iterations: usize,
726    timeout: Duration,
727    graph_ctx: Arc<GraphExecutionContext>,
728    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
729    storage: Arc<StorageManager>,
730    schema_info: Arc<UniSchema>,
731    params: HashMap<String, Value>,
732    registry: Arc<DerivedScanRegistry>,
733    output_schema: SchemaRef,
734    max_derived_bytes: usize,
735    derivation_tracker: Option<Arc<DerivationTracker>>,
736    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
737) -> DFResult<Vec<RecordBatch>> {
738    let start = Instant::now();
739    let task_ctx = session_ctx.read().task_ctx();
740
741    // Initialize per-rule state
742    let mut states: Vec<FixpointState> = rules
743        .iter()
744        .map(|rule| {
745            let monotonic_agg = if !rule.fold_bindings.is_empty() {
746                let bindings: Vec<MonotonicFoldBinding> = rule
747                    .fold_bindings
748                    .iter()
749                    .map(|fb| MonotonicFoldBinding {
750                        fold_name: fb.output_name.clone(),
751                        kind: fb.kind.clone(),
752                        input_col_index: fb.input_col_index,
753                    })
754                    .collect();
755                Some(MonotonicAggState::new(bindings))
756            } else {
757                None
758            };
759            FixpointState::new(
760                rule.name.clone(),
761                Arc::clone(&rule.yield_schema),
762                rule.key_column_indices.clone(),
763                max_derived_bytes,
764                monotonic_agg,
765            )
766        })
767        .collect();
768
769    // Main iteration loop
770    let mut converged = false;
771    let mut total_iters = 0usize;
772    for iteration in 0..max_iterations {
773        total_iters = iteration + 1;
774        tracing::debug!("fixpoint iteration {}", iteration);
775        let mut any_changed = false;
776
777        for rule_idx in 0..rules.len() {
778            let rule = &rules[rule_idx];
779
780            // Update derived scan handles for this rule's clauses
781            update_derived_scan_handles(&registry, &states, rule_idx, &rules);
782
783            // Evaluate clause bodies, tracking per-clause candidates for provenance.
784            let mut all_candidates = Vec::new();
785            let mut clause_candidates: Vec<Vec<RecordBatch>> = Vec::new();
786            for clause in &rule.clauses {
787                let mut batches = execute_subplan(
788                    &clause.body_logical,
789                    &params,
790                    &HashMap::new(),
791                    &graph_ctx,
792                    &session_ctx,
793                    &storage,
794                    &schema_info,
795                )
796                .await?;
797                // Apply anti-joins for negated IS-refs (IS NOT semantics).
798                for binding in &clause.is_ref_bindings {
799                    if binding.negated
800                        && !binding.anti_join_cols.is_empty()
801                        && let Some(entry) = registry.get(binding.derived_scan_index)
802                    {
803                        let neg_facts = entry.data.read().clone();
804                        if !neg_facts.is_empty() {
805                            for (left_col, right_col) in &binding.anti_join_cols {
806                                batches =
807                                    apply_anti_join(batches, &neg_facts, left_col, right_col)?;
808                            }
809                        }
810                    }
811                }
812                clause_candidates.push(batches.clone());
813                all_candidates.extend(batches);
814            }
815
816            // Merge delta
817            let changed = states[rule_idx]
818                .merge_delta(all_candidates, Some(Arc::clone(&task_ctx)))
819                .await?;
820            if changed {
821                any_changed = true;
822                // Record provenance for newly derived facts when tracker is present.
823                if let Some(ref tracker) = derivation_tracker {
824                    record_provenance(
825                        tracker,
826                        rule,
827                        &states[rule_idx],
828                        &clause_candidates,
829                        iteration,
830                    );
831                }
832            }
833        }
834
835        // Check convergence
836        if !any_changed && states.iter().all(|s| s.is_converged()) {
837            tracing::debug!("fixpoint converged after {} iterations", iteration + 1);
838            converged = true;
839            break;
840        }
841
842        // Check timeout
843        if start.elapsed() > timeout {
844            return Err(datafusion::error::DataFusionError::Execution(
845                LocyRuntimeError::NonConvergence {
846                    iterations: iteration + 1,
847                }
848                .to_string(),
849            ));
850        }
851    }
852
853    // Write per-rule iteration counts to the shared slot.
854    if let Ok(mut counts) = iteration_counts.write() {
855        for rule in &rules {
856            counts.insert(rule.name.clone(), total_iters);
857        }
858    }
859
860    // If we exhausted all iterations without converging, return a non-convergence error.
861    if !converged {
862        return Err(datafusion::error::DataFusionError::Execution(
863            LocyRuntimeError::NonConvergence {
864                iterations: max_iterations,
865            }
866            .to_string(),
867        ));
868    }
869
870    // Post-fixpoint processing per rule and collect output
871    let task_ctx = session_ctx.read().task_ctx();
872    let mut all_output = Vec::new();
873
874    for (rule_idx, state) in states.into_iter().enumerate() {
875        let rule = &rules[rule_idx];
876        let facts = state.into_facts();
877        if facts.is_empty() {
878            continue;
879        }
880
881        let processed = apply_post_fixpoint_chain(facts, rule, &task_ctx).await?;
882        all_output.extend(processed);
883    }
884
885    // If no output, return empty batch with output schema
886    if all_output.is_empty() {
887        all_output.push(RecordBatch::new_empty(output_schema));
888    }
889
890    Ok(all_output)
891}
892
893// ---------------------------------------------------------------------------
894// Provenance recording helpers
895// ---------------------------------------------------------------------------
896
897/// Record provenance for all newly derived facts (rows in the current delta).
898///
899/// Called after `merge_delta` returns `true`. Attributes each new fact to the
900/// clause most likely to have produced it, using first-derivation-wins semantics.
901fn record_provenance(
902    tracker: &Arc<DerivationTracker>,
903    rule: &FixpointRulePlan,
904    state: &FixpointState,
905    clause_candidates: &[Vec<RecordBatch>],
906    iteration: usize,
907) {
908    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
909
910    for delta_batch in state.all_delta() {
911        for row_idx in 0..delta_batch.num_rows() {
912            let row_hash = format!(
913                "{:?}",
914                extract_scalar_key(delta_batch, &all_indices, row_idx)
915            )
916            .into_bytes();
917            let fact_row = batch_row_to_value_map(delta_batch, row_idx);
918            let clause_index =
919                find_clause_for_row(delta_batch, row_idx, &all_indices, clause_candidates);
920
921            let entry = DerivationEntry {
922                rule_name: rule.name.clone(),
923                clause_index,
924                inputs: vec![],
925                along_values: std::collections::HashMap::new(),
926                iteration,
927                fact_row,
928            };
929            tracker.record(row_hash, entry);
930        }
931    }
932}
933
934/// Determine which clause produced a given row by checking each clause's candidates.
935///
936/// Returns the index of the first clause whose candidates contain a matching row.
937/// Falls back to 0 if no match is found.
938fn find_clause_for_row(
939    delta_batch: &RecordBatch,
940    row_idx: usize,
941    all_indices: &[usize],
942    clause_candidates: &[Vec<RecordBatch>],
943) -> usize {
944    let target_key = extract_scalar_key(delta_batch, all_indices, row_idx);
945    for (clause_idx, batches) in clause_candidates.iter().enumerate() {
946        for batch in batches {
947            if batch.num_columns() != all_indices.len() {
948                continue;
949            }
950            for r in 0..batch.num_rows() {
951                if extract_scalar_key(batch, all_indices, r) == target_key {
952                    return clause_idx;
953                }
954            }
955        }
956    }
957    0
958}
959
960/// Convert a single row from a `RecordBatch` at `row_idx` into a `HashMap<String, Value>`.
961fn batch_row_to_value_map(
962    batch: &RecordBatch,
963    row_idx: usize,
964) -> std::collections::HashMap<String, Value> {
965    use uni_store::storage::arrow_convert::arrow_to_value;
966
967    let schema = batch.schema();
968    schema
969        .fields()
970        .iter()
971        .enumerate()
972        .map(|(col_idx, field)| {
973            let col = batch.column(col_idx);
974            let val = arrow_to_value(col.as_ref(), row_idx, None);
975            (field.name().clone(), val)
976        })
977        .collect()
978}
979
980/// Filter `batches` to exclude rows where `left_col` VID appears in `neg_facts[right_col]`.
981///
982/// Implements anti-join semantics for negated IS-refs (`n IS NOT rule`): keeps only
983/// rows whose subject VID is NOT present in the negated rule's fully-converged facts.
984pub fn apply_anti_join(
985    batches: Vec<RecordBatch>,
986    neg_facts: &[RecordBatch],
987    left_col: &str,
988    right_col: &str,
989) -> datafusion::error::Result<Vec<RecordBatch>> {
990    use arrow::compute::filter_record_batch;
991    use arrow_array::{Array as _, BooleanArray, UInt64Array};
992
993    // Collect right-side VIDs from the negated rule's derived facts.
994    let mut banned: std::collections::HashSet<u64> = std::collections::HashSet::new();
995    for batch in neg_facts {
996        let Ok(idx) = batch.schema().index_of(right_col) else {
997            continue;
998        };
999        let arr = batch.column(idx);
1000        let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
1001            continue;
1002        };
1003        for i in 0..vids.len() {
1004            if !vids.is_null(i) {
1005                banned.insert(vids.value(i));
1006            }
1007        }
1008    }
1009
1010    if banned.is_empty() {
1011        return Ok(batches);
1012    }
1013
1014    // Filter body batches: keep rows where left_col NOT IN banned.
1015    let mut result = Vec::new();
1016    for batch in batches {
1017        let Ok(idx) = batch.schema().index_of(left_col) else {
1018            result.push(batch);
1019            continue;
1020        };
1021        let arr = batch.column(idx);
1022        let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
1023            result.push(batch);
1024            continue;
1025        };
1026        let keep: Vec<bool> = (0..vids.len())
1027            .map(|i| vids.is_null(i) || !banned.contains(&vids.value(i)))
1028            .collect();
1029        let keep_arr = BooleanArray::from(keep);
1030        let filtered = filter_record_batch(&batch, &keep_arr)
1031            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
1032        if filtered.num_rows() > 0 {
1033            result.push(filtered);
1034        }
1035    }
1036    Ok(result)
1037}
1038
1039/// Update derived scan handles before evaluating a rule's clause bodies.
1040///
1041/// For self-references: inject delta (semi-naive optimization).
1042/// For cross-references: inject full facts.
1043fn update_derived_scan_handles(
1044    registry: &DerivedScanRegistry,
1045    states: &[FixpointState],
1046    current_rule_idx: usize,
1047    rules: &[FixpointRulePlan],
1048) {
1049    let current_rule_name = &rules[current_rule_idx].name;
1050
1051    for entry in &registry.entries {
1052        // Find the state for this entry's rule
1053        let source_state_idx = rules.iter().position(|r| r.name == entry.rule_name);
1054        let Some(source_idx) = source_state_idx else {
1055            continue;
1056        };
1057
1058        let is_self = entry.rule_name == *current_rule_name;
1059        let data = if is_self {
1060            // Self-ref: inject delta for semi-naive
1061            states[source_idx].all_delta().to_vec()
1062        } else {
1063            // Cross-ref: inject full facts
1064            states[source_idx].all_facts().to_vec()
1065        };
1066
1067        // If empty, write an empty batch so the scan returns zero rows
1068        let data = if data.is_empty() || data.iter().all(|b| b.num_rows() == 0) {
1069            vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
1070        } else {
1071            data
1072        };
1073
1074        let mut guard = entry.data.write();
1075        *guard = data;
1076    }
1077}
1078
1079// ---------------------------------------------------------------------------
1080// DerivedScanExec — physical plan that reads from shared data at execution time
1081// ---------------------------------------------------------------------------
1082
/// Physical plan for `LocyDerivedScan` that reads from a shared `Arc<RwLock>` at
/// execution time (not at plan creation time).
///
/// This is critical for fixpoint iteration: the data handle is updated between
/// iterations, and each re-execution of the subplan must read the latest data.
pub struct DerivedScanExec {
    /// Shared batch store, rewritten between iterations by
    /// `update_derived_scan_handles`; read fresh on every `execute`.
    data: Arc<RwLock<Vec<RecordBatch>>>,
    /// Output schema of this scan.
    schema: SchemaRef,
    /// Cached plan properties computed from `schema`.
    properties: PlanProperties,
}
1093
1094impl DerivedScanExec {
1095    pub fn new(data: Arc<RwLock<Vec<RecordBatch>>>, schema: SchemaRef) -> Self {
1096        let properties = compute_plan_properties(Arc::clone(&schema));
1097        Self {
1098            data,
1099            schema,
1100            properties,
1101        }
1102    }
1103}
1104
impl fmt::Debug for DerivedScanExec {
    // Manual impl: only the schema is rendered; the shared data handle and the
    // plan properties are deliberately omitted from debug output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("DerivedScanExec")
            .field("schema", &self.schema)
            .finish()
    }
}
1112
impl DisplayAs for DerivedScanExec {
    // Fixed one-line label used when DataFusion renders the physical plan.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "DerivedScanExec")
    }
}
1118
1119impl ExecutionPlan for DerivedScanExec {
1120    fn name(&self) -> &str {
1121        "DerivedScanExec"
1122    }
1123    fn as_any(&self) -> &dyn Any {
1124        self
1125    }
1126    fn schema(&self) -> SchemaRef {
1127        Arc::clone(&self.schema)
1128    }
1129    fn properties(&self) -> &PlanProperties {
1130        &self.properties
1131    }
1132    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1133        vec![]
1134    }
1135    fn with_new_children(
1136        self: Arc<Self>,
1137        _children: Vec<Arc<dyn ExecutionPlan>>,
1138    ) -> DFResult<Arc<dyn ExecutionPlan>> {
1139        Ok(self)
1140    }
1141    fn execute(
1142        &self,
1143        _partition: usize,
1144        _context: Arc<TaskContext>,
1145    ) -> DFResult<SendableRecordBatchStream> {
1146        let batches = {
1147            let guard = self.data.read();
1148            if guard.is_empty() {
1149                vec![RecordBatch::new_empty(Arc::clone(&self.schema))]
1150            } else {
1151                guard.clone()
1152            }
1153        };
1154        Ok(Box::pin(MemoryStream::try_new(
1155            batches,
1156            Arc::clone(&self.schema),
1157            None,
1158        )?))
1159    }
1160}
1161
1162// ---------------------------------------------------------------------------
1163// InMemoryExec — wrapper to feed Vec<RecordBatch> into operator chains
1164// ---------------------------------------------------------------------------
1165
/// Simple in-memory execution plan that serves pre-computed batches.
///
/// Used internally to feed fixpoint results into post-fixpoint operator chains
/// (FOLD, BEST BY). Not exported — only used within this module.
struct InMemoryExec {
    /// Batches served (cloned) on every `execute` call.
    batches: Vec<RecordBatch>,
    /// Declared output schema.
    schema: SchemaRef,
    /// Cached plan properties computed from `schema`.
    properties: PlanProperties,
}
1175
1176impl InMemoryExec {
1177    fn new(batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
1178        let properties = compute_plan_properties(Arc::clone(&schema));
1179        Self {
1180            batches,
1181            schema,
1182            properties,
1183        }
1184    }
1185}
1186
impl fmt::Debug for InMemoryExec {
    // Manual impl: reports batch count and schema; `PlanProperties` is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("InMemoryExec")
            .field("num_batches", &self.batches.len())
            .field("schema", &self.schema)
            .finish()
    }
}
1195
impl DisplayAs for InMemoryExec {
    // One-line plan label including the number of served batches.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "InMemoryExec: batches={}", self.batches.len())
    }
}
1201
1202impl ExecutionPlan for InMemoryExec {
1203    fn name(&self) -> &str {
1204        "InMemoryExec"
1205    }
1206    fn as_any(&self) -> &dyn Any {
1207        self
1208    }
1209    fn schema(&self) -> SchemaRef {
1210        Arc::clone(&self.schema)
1211    }
1212    fn properties(&self) -> &PlanProperties {
1213        &self.properties
1214    }
1215    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1216        vec![]
1217    }
1218    fn with_new_children(
1219        self: Arc<Self>,
1220        _children: Vec<Arc<dyn ExecutionPlan>>,
1221    ) -> DFResult<Arc<dyn ExecutionPlan>> {
1222        Ok(self)
1223    }
1224    fn execute(
1225        &self,
1226        _partition: usize,
1227        _context: Arc<TaskContext>,
1228    ) -> DFResult<SendableRecordBatchStream> {
1229        Ok(Box::pin(MemoryStream::try_new(
1230            self.batches.clone(),
1231            Arc::clone(&self.schema),
1232            None,
1233        )?))
1234    }
1235}
1236
1237// ---------------------------------------------------------------------------
1238// Post-fixpoint chain — FOLD and BEST BY on converged facts
1239// ---------------------------------------------------------------------------
1240
1241/// Apply post-fixpoint operators (FOLD, BEST BY, PRIORITY) to converged facts.
1242pub(crate) async fn apply_post_fixpoint_chain(
1243    facts: Vec<RecordBatch>,
1244    rule: &FixpointRulePlan,
1245    task_ctx: &Arc<TaskContext>,
1246) -> DFResult<Vec<RecordBatch>> {
1247    if !rule.has_fold && !rule.has_best_by && !rule.has_priority {
1248        return Ok(facts);
1249    }
1250
1251    // Wrap facts in InMemoryExec
1252    let schema = Arc::clone(&rule.yield_schema);
1253    let input: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(facts, schema));
1254
1255    // Apply PRIORITY first — keeps only rows with max __priority per KEY group,
1256    // then strips the __priority column from output.
1257    // Must run before FOLD so that the __priority column is still present.
1258    let current: Arc<dyn ExecutionPlan> = if rule.has_priority {
1259        let priority_schema = input.schema();
1260        let priority_idx = priority_schema.index_of("__priority").map_err(|_| {
1261            datafusion::common::DataFusionError::Internal(
1262                "PRIORITY rule missing __priority column".to_string(),
1263            )
1264        })?;
1265        Arc::new(PriorityExec::new(
1266            input,
1267            rule.key_column_indices.clone(),
1268            priority_idx,
1269        ))
1270    } else {
1271        input
1272    };
1273
1274    // Apply FOLD
1275    let current: Arc<dyn ExecutionPlan> = if rule.has_fold && !rule.fold_bindings.is_empty() {
1276        Arc::new(FoldExec::new(
1277            current,
1278            rule.key_column_indices.clone(),
1279            rule.fold_bindings.clone(),
1280        ))
1281    } else {
1282        current
1283    };
1284
1285    // Apply BEST BY
1286    let current: Arc<dyn ExecutionPlan> = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
1287        Arc::new(BestByExec::new(
1288            current,
1289            rule.key_column_indices.clone(),
1290            rule.best_by_criteria.clone(),
1291            rule.deterministic,
1292        ))
1293    } else {
1294        current
1295    };
1296
1297    collect_all_partitions(&current, Arc::clone(task_ctx)).await
1298}
1299
1300// ---------------------------------------------------------------------------
1301// FixpointExec — DataFusion ExecutionPlan
1302// ---------------------------------------------------------------------------
1303
/// DataFusion `ExecutionPlan` that drives semi-naive fixpoint iteration.
///
/// Has no physical children: clause bodies are re-planned from logical plans
/// on each iteration (same pattern as `RecursiveCTEExec` and `GraphApplyExec`).
pub struct FixpointExec {
    /// Recursive rules of this stratum, evaluated every iteration.
    rules: Vec<FixpointRulePlan>,
    /// Iteration cap; exceeding it yields a `NonConvergence` error.
    max_iterations: usize,
    /// Wall-clock budget for the whole fixpoint loop.
    timeout: Duration,
    /// Graph context forwarded to clause-body subplans.
    graph_ctx: Arc<GraphExecutionContext>,
    /// Session used to execute clause bodies and obtain task contexts.
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    /// Storage layer handle forwarded to clause-body subplans.
    storage: Arc<StorageManager>,
    /// Uni schema information forwarded to clause-body subplans.
    schema_info: Arc<UniSchema>,
    /// Query parameters passed through to clause-body subplans.
    params: HashMap<String, Value>,
    /// Registry of derived-scan handles refreshed between iterations.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    /// Schema of the batches this operator emits.
    output_schema: SchemaRef,
    /// Cached plan properties computed from `output_schema`.
    properties: PlanProperties,
    /// Execution metrics set (baseline metrics per partition).
    metrics: ExecutionPlanMetricsSet,
    /// Cap on derived-fact memory (bytes); forwarded to each `FixpointState`.
    max_derived_bytes: usize,
    /// Optional provenance tracker populated during fixpoint iteration.
    derivation_tracker: Option<Arc<DerivationTracker>>,
    /// Shared slot written with per-rule iteration counts after convergence.
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
}
1327
impl fmt::Debug for FixpointExec {
    // Manual impl: summarizes counts and limits; contexts, params, and the
    // registry are omitted (hence `finish_non_exhaustive`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("FixpointExec")
            .field("rules_count", &self.rules.len())
            .field("max_iterations", &self.max_iterations)
            .field("timeout", &self.timeout)
            .field("output_schema", &self.output_schema)
            .field("max_derived_bytes", &self.max_derived_bytes)
            .finish_non_exhaustive()
    }
}
1339
1340impl FixpointExec {
1341    /// Create a new `FixpointExec`.
1342    #[allow(clippy::too_many_arguments)]
1343    pub fn new(
1344        rules: Vec<FixpointRulePlan>,
1345        max_iterations: usize,
1346        timeout: Duration,
1347        graph_ctx: Arc<GraphExecutionContext>,
1348        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
1349        storage: Arc<StorageManager>,
1350        schema_info: Arc<UniSchema>,
1351        params: HashMap<String, Value>,
1352        derived_scan_registry: Arc<DerivedScanRegistry>,
1353        output_schema: SchemaRef,
1354        max_derived_bytes: usize,
1355        derivation_tracker: Option<Arc<DerivationTracker>>,
1356        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
1357    ) -> Self {
1358        let properties = compute_plan_properties(Arc::clone(&output_schema));
1359        Self {
1360            rules,
1361            max_iterations,
1362            timeout,
1363            graph_ctx,
1364            session_ctx,
1365            storage,
1366            schema_info,
1367            params,
1368            derived_scan_registry,
1369            output_schema,
1370            properties,
1371            metrics: ExecutionPlanMetricsSet::new(),
1372            max_derived_bytes,
1373            derivation_tracker,
1374            iteration_counts,
1375        }
1376    }
1377
1378    /// Returns the shared iteration counts slot for post-execution inspection.
1379    pub fn iteration_counts(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
1380        Arc::clone(&self.iteration_counts)
1381    }
1382}
1383
1384impl DisplayAs for FixpointExec {
1385    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1386        write!(
1387            f,
1388            "FixpointExec: rules=[{}], max_iter={}, timeout={:?}",
1389            self.rules
1390                .iter()
1391                .map(|r| r.name.as_str())
1392                .collect::<Vec<_>>()
1393                .join(", "),
1394            self.max_iterations,
1395            self.timeout,
1396        )
1397    }
1398}
1399
impl ExecutionPlan for FixpointExec {
    fn name(&self) -> &str {
        "FixpointExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.output_schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        // No physical children — clause bodies are re-planned each iteration
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        // Reject any attempt to attach children: this operator is a leaf.
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Plan(
                "FixpointExec has no children".to_string(),
            ));
        }
        Ok(self)
    }

    /// Launch the fixpoint loop as a single boxed future and return a
    /// `FixpointStream` that drives it to completion, then emits the
    /// accumulated result batches one at a time.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let metrics = BaselineMetrics::new(&self.metrics, partition);

        // Clone all fields for the async closure
        let rules = self
            .rules
            .iter()
            .map(|r| {
                // We need to clone the FixpointRulePlan, but it contains LogicalPlan
                // which doesn't implement Clone traditionally. However, our LogicalPlan
                // does implement Clone since it's an enum.
                FixpointRulePlan {
                    name: r.name.clone(),
                    clauses: r
                        .clauses
                        .iter()
                        .map(|c| FixpointClausePlan {
                            body_logical: c.body_logical.clone(),
                            is_ref_bindings: c.is_ref_bindings.clone(),
                            priority: c.priority,
                        })
                        .collect(),
                    yield_schema: Arc::clone(&r.yield_schema),
                    key_column_indices: r.key_column_indices.clone(),
                    priority: r.priority,
                    has_fold: r.has_fold,
                    fold_bindings: r.fold_bindings.clone(),
                    has_best_by: r.has_best_by,
                    best_by_criteria: r.best_by_criteria.clone(),
                    has_priority: r.has_priority,
                    deterministic: r.deterministic,
                }
            })
            .collect();

        // Cheap per-field clones (mostly Arc bumps) so the future is 'static.
        let max_iterations = self.max_iterations;
        let timeout = self.timeout;
        let graph_ctx = Arc::clone(&self.graph_ctx);
        let session_ctx = Arc::clone(&self.session_ctx);
        let storage = Arc::clone(&self.storage);
        let schema_info = Arc::clone(&self.schema_info);
        let params = self.params.clone();
        let registry = Arc::clone(&self.derived_scan_registry);
        let output_schema = Arc::clone(&self.output_schema);
        let max_derived_bytes = self.max_derived_bytes;
        let derivation_tracker = self.derivation_tracker.clone();
        let iteration_counts = Arc::clone(&self.iteration_counts);

        // The whole fixpoint loop runs inside one future; the stream below
        // polls it and only starts emitting once it resolves.
        let fut = async move {
            run_fixpoint_loop(
                rules,
                max_iterations,
                timeout,
                graph_ctx,
                session_ctx,
                storage,
                schema_info,
                params,
                registry,
                output_schema,
                max_derived_bytes,
                derivation_tracker,
                iteration_counts,
            )
            .await
        };

        Ok(Box::pin(FixpointStream {
            state: FixpointStreamState::Running(Box::pin(fut)),
            schema: Arc::clone(&self.output_schema),
            metrics,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
1516
1517// ---------------------------------------------------------------------------
1518// FixpointStream — async state machine for streaming results
1519// ---------------------------------------------------------------------------
1520
/// Phases of `FixpointStream`.
enum FixpointStreamState {
    /// Fixpoint loop is running.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// Emitting accumulated result batches one at a time.
    /// The `usize` is the index of the next batch to emit.
    Emitting(Vec<RecordBatch>, usize),
    /// All batches emitted.
    Done,
}
1529
/// Stream adapter over the fixpoint future: drives the loop to completion,
/// then emits the resulting batches one at a time.
struct FixpointStream {
    /// Current phase: running the loop, emitting batches, or done.
    state: FixpointStreamState,
    /// Schema reported via `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Baseline metrics; output rows are recorded as batches are emitted.
    metrics: BaselineMetrics,
}
1535
1536impl Stream for FixpointStream {
1537    type Item = DFResult<RecordBatch>;
1538
1539    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1540        let this = self.get_mut();
1541        loop {
1542            match &mut this.state {
1543                FixpointStreamState::Running(fut) => match fut.as_mut().poll(cx) {
1544                    Poll::Ready(Ok(batches)) => {
1545                        if batches.is_empty() {
1546                            this.state = FixpointStreamState::Done;
1547                            return Poll::Ready(None);
1548                        }
1549                        this.state = FixpointStreamState::Emitting(batches, 0);
1550                        // Loop to emit first batch
1551                    }
1552                    Poll::Ready(Err(e)) => {
1553                        this.state = FixpointStreamState::Done;
1554                        return Poll::Ready(Some(Err(e)));
1555                    }
1556                    Poll::Pending => return Poll::Pending,
1557                },
1558                FixpointStreamState::Emitting(batches, idx) => {
1559                    if *idx >= batches.len() {
1560                        this.state = FixpointStreamState::Done;
1561                        return Poll::Ready(None);
1562                    }
1563                    let batch = batches[*idx].clone();
1564                    *idx += 1;
1565                    this.metrics.record_output(batch.num_rows());
1566                    return Poll::Ready(Some(Ok(batch)));
1567                }
1568                FixpointStreamState::Done => return Poll::Ready(None),
1569            }
1570        }
1571    }
1572}
1573
1574impl RecordBatchStream for FixpointStream {
1575    fn schema(&self) -> SchemaRef {
1576        Arc::clone(&self.schema)
1577    }
1578}
1579
1580// ---------------------------------------------------------------------------
1581// Unit tests
1582// ---------------------------------------------------------------------------
1583
1584#[cfg(test)]
1585mod tests {
1586    use super::*;
1587    use arrow_array::{Float64Array, Int64Array, StringArray};
1588    use arrow_schema::{DataType, Field, Schema};
1589
1590    fn test_schema() -> SchemaRef {
1591        Arc::new(Schema::new(vec![
1592            Field::new("name", DataType::Utf8, true),
1593            Field::new("value", DataType::Int64, true),
1594        ]))
1595    }
1596
1597    fn make_batch(names: &[&str], values: &[i64]) -> RecordBatch {
1598        RecordBatch::try_new(
1599            test_schema(),
1600            vec![
1601                Arc::new(StringArray::from(
1602                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
1603                )),
1604                Arc::new(Int64Array::from(values.to_vec())),
1605            ],
1606        )
1607        .unwrap()
1608    }
1609
1610    // --- FixpointState dedup tests ---
1611
1612    #[tokio::test]
1613    async fn test_fixpoint_state_empty_facts_adds_all() {
1614        let schema = test_schema();
1615        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None);
1616
1617        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
1618        let changed = state.merge_delta(vec![batch], None).await.unwrap();
1619
1620        assert!(changed);
1621        assert_eq!(state.all_facts().len(), 1);
1622        assert_eq!(state.all_facts()[0].num_rows(), 3);
1623        assert_eq!(state.all_delta().len(), 1);
1624        assert_eq!(state.all_delta()[0].num_rows(), 3);
1625    }
1626
1627    #[tokio::test]
1628    async fn test_fixpoint_state_exact_duplicates_excluded() {
1629        let schema = test_schema();
1630        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None);
1631
1632        let batch1 = make_batch(&["a", "b"], &[1, 2]);
1633        state.merge_delta(vec![batch1], None).await.unwrap();
1634
1635        // Same rows again
1636        let batch2 = make_batch(&["a", "b"], &[1, 2]);
1637        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
1638        assert!(!changed);
1639        assert!(
1640            state.all_delta().is_empty() || state.all_delta().iter().all(|b| b.num_rows() == 0)
1641        );
1642    }
1643
1644    #[tokio::test]
1645    async fn test_fixpoint_state_partial_overlap() {
1646        let schema = test_schema();
1647        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None);
1648
1649        let batch1 = make_batch(&["a", "b"], &[1, 2]);
1650        state.merge_delta(vec![batch1], None).await.unwrap();
1651
1652        // "a":1 is duplicate, "c":3 is new
1653        let batch2 = make_batch(&["a", "c"], &[1, 3]);
1654        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
1655        assert!(changed);
1656
1657        // Delta should have only "c":3
1658        let delta_rows: usize = state.all_delta().iter().map(|b| b.num_rows()).sum();
1659        assert_eq!(delta_rows, 1);
1660
1661        // Total facts: a:1, b:2, c:3
1662        let total_rows: usize = state.all_facts().iter().map(|b| b.num_rows()).sum();
1663        assert_eq!(total_rows, 3);
1664    }
1665
1666    #[tokio::test]
1667    async fn test_fixpoint_state_convergence() {
1668        let schema = test_schema();
1669        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None);
1670
1671        let batch = make_batch(&["a"], &[1]);
1672        state.merge_delta(vec![batch], None).await.unwrap();
1673
1674        // Empty candidates → converged
1675        let changed = state.merge_delta(vec![], None).await.unwrap();
1676        assert!(!changed);
1677        assert!(state.is_converged());
1678    }
1679
1680    // --- RowDedupState tests ---
1681
1682    #[test]
1683    fn test_row_dedup_persistent_across_calls() {
1684        // RowDedupState should remember rows from the first call so the second
1685        // call does not re-accept them (O(M) per iteration, no facts re-scan).
1686        let schema = test_schema();
1687        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
1688
1689        let batch1 = make_batch(&["a", "b"], &[1, 2]);
1690        let delta1 = rd.compute_delta(&[batch1], &schema).unwrap();
1691        // First call: both rows are new.
1692        let rows1: usize = delta1.iter().map(|b| b.num_rows()).sum();
1693        assert_eq!(rows1, 2);
1694
1695        // Second call with same rows: seen set already has them → empty delta.
1696        let batch2 = make_batch(&["a", "b"], &[1, 2]);
1697        let delta2 = rd.compute_delta(&[batch2], &schema).unwrap();
1698        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
1699        assert_eq!(rows2, 0);
1700
1701        // Third call with one old + one new: only the new row is returned.
1702        let batch3 = make_batch(&["a", "c"], &[1, 3]);
1703        let delta3 = rd.compute_delta(&[batch3], &schema).unwrap();
1704        let rows3: usize = delta3.iter().map(|b| b.num_rows()).sum();
1705        assert_eq!(rows3, 1);
1706    }
1707
1708    #[test]
1709    fn test_row_dedup_null_handling() {
1710        use arrow_array::StringArray;
1711        use arrow_schema::{DataType, Field, Schema};
1712
1713        let schema: SchemaRef = Arc::new(Schema::new(vec![
1714            Field::new("a", DataType::Utf8, true),
1715            Field::new("b", DataType::Int64, true),
1716        ]));
1717        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
1718
1719        // Two rows: (NULL, 1) and (NULL, 1) — same NULLs → duplicate.
1720        let batch_nulls = RecordBatch::try_new(
1721            Arc::clone(&schema),
1722            vec![
1723                Arc::new(StringArray::from(vec![None::<&str>, None::<&str>])),
1724                Arc::new(arrow_array::Int64Array::from(vec![1i64, 1i64])),
1725            ],
1726        )
1727        .unwrap();
1728        let delta = rd.compute_delta(&[batch_nulls], &schema).unwrap();
1729        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
1730        assert_eq!(rows, 1, "two identical NULL rows should be deduped to one");
1731
1732        // (NULL, 2) — NULL in same col but different non-null col → distinct.
1733        let batch_diff = RecordBatch::try_new(
1734            Arc::clone(&schema),
1735            vec![
1736                Arc::new(StringArray::from(vec![None::<&str>])),
1737                Arc::new(arrow_array::Int64Array::from(vec![2i64])),
1738            ],
1739        )
1740        .unwrap();
1741        let delta2 = rd.compute_delta(&[batch_diff], &schema).unwrap();
1742        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
1743        assert_eq!(rows2, 1, "(NULL, 2) is distinct from (NULL, 1)");
1744    }
1745
1746    #[test]
1747    fn test_row_dedup_within_candidate_dedup() {
1748        // Duplicate rows within a single candidate batch should be collapsed to one.
1749        let schema = test_schema();
1750        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
1751
1752        // Batch with three rows: a:1, a:1, b:2 — "a:1" appears twice.
1753        let batch = make_batch(&["a", "a", "b"], &[1, 1, 2]);
1754        let delta = rd.compute_delta(&[batch], &schema).unwrap();
1755        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
1756        assert_eq!(rows, 2, "within-batch dup should be collapsed: a:1, b:2");
1757    }
1758
1759    // --- Float rounding tests ---
1760
1761    #[test]
1762    fn test_round_float_columns_near_duplicates() {
1763        let schema = Arc::new(Schema::new(vec![
1764            Field::new("name", DataType::Utf8, true),
1765            Field::new("dist", DataType::Float64, true),
1766        ]));
1767        let batch = RecordBatch::try_new(
1768            Arc::clone(&schema),
1769            vec![
1770                Arc::new(StringArray::from(vec![Some("a"), Some("a")])),
1771                Arc::new(Float64Array::from(vec![1.0000000000001, 1.0000000000002])),
1772            ],
1773        )
1774        .unwrap();
1775
1776        let rounded = round_float_columns(&[batch]);
1777        assert_eq!(rounded.len(), 1);
1778        let col = rounded[0]
1779            .column(1)
1780            .as_any()
1781            .downcast_ref::<Float64Array>()
1782            .unwrap();
1783        // Both should round to same value
1784        assert_eq!(col.value(0), col.value(1));
1785    }
1786
1787    // --- DerivedScanRegistry tests ---
1788
1789    #[test]
1790    fn test_registry_write_read_round_trip() {
1791        let schema = test_schema();
1792        let data = Arc::new(RwLock::new(Vec::new()));
1793        let mut reg = DerivedScanRegistry::new();
1794        reg.add(DerivedScanEntry {
1795            scan_index: 0,
1796            rule_name: "reachable".into(),
1797            is_self_ref: true,
1798            data: Arc::clone(&data),
1799            schema: Arc::clone(&schema),
1800        });
1801
1802        let batch = make_batch(&["x"], &[42]);
1803        reg.write_data(0, vec![batch.clone()]);
1804
1805        let entry = reg.get(0).unwrap();
1806        let guard = entry.data.read();
1807        assert_eq!(guard.len(), 1);
1808        assert_eq!(guard[0].num_rows(), 1);
1809    }
1810
1811    #[test]
1812    fn test_registry_entries_for_rule() {
1813        let schema = test_schema();
1814        let mut reg = DerivedScanRegistry::new();
1815        reg.add(DerivedScanEntry {
1816            scan_index: 0,
1817            rule_name: "r1".into(),
1818            is_self_ref: true,
1819            data: Arc::new(RwLock::new(Vec::new())),
1820            schema: Arc::clone(&schema),
1821        });
1822        reg.add(DerivedScanEntry {
1823            scan_index: 1,
1824            rule_name: "r2".into(),
1825            is_self_ref: false,
1826            data: Arc::new(RwLock::new(Vec::new())),
1827            schema: Arc::clone(&schema),
1828        });
1829        reg.add(DerivedScanEntry {
1830            scan_index: 2,
1831            rule_name: "r1".into(),
1832            is_self_ref: false,
1833            data: Arc::new(RwLock::new(Vec::new())),
1834            schema: Arc::clone(&schema),
1835        });
1836
1837        assert_eq!(reg.entries_for_rule("r1").len(), 2);
1838        assert_eq!(reg.entries_for_rule("r2").len(), 1);
1839        assert_eq!(reg.entries_for_rule("r3").len(), 0);
1840    }
1841
1842    // --- MonotonicAggState tests ---
1843
1844    #[test]
1845    fn test_monotonic_agg_update_and_stability() {
1846        use crate::query::df_graph::locy_fold::FoldAggKind;
1847
1848        let bindings = vec![MonotonicFoldBinding {
1849            fold_name: "total".into(),
1850            kind: FoldAggKind::Sum,
1851            input_col_index: 1,
1852        }];
1853        let mut agg = MonotonicAggState::new(bindings);
1854
1855        // First update
1856        let batch = make_batch(&["a"], &[10]);
1857        agg.snapshot();
1858        let changed = agg.update(&[0], &[batch]);
1859        assert!(changed);
1860        assert!(!agg.is_stable()); // changed since snapshot
1861
1862        // Snapshot and check stability with no new data
1863        agg.snapshot();
1864        let changed = agg.update(&[0], &[]);
1865        assert!(!changed);
1866        assert!(agg.is_stable());
1867    }
1868
1869    // --- Memory limit test ---
1870
1871    #[tokio::test]
1872    async fn test_memory_limit_exceeded() {
1873        let schema = test_schema();
1874        // Set a tiny limit
1875        let mut state = FixpointState::new("test".into(), schema, vec![0], 1, None);
1876
1877        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
1878        let result = state.merge_delta(vec![batch], None).await;
1879        assert!(result.is_err());
1880        let err = result.unwrap_err().to_string();
1881        assert!(err.contains("memory limit"), "Error was: {}", err);
1882    }
1883
1884    // --- FixpointStream lifecycle test ---
1885
    #[tokio::test]
    async fn test_fixpoint_stream_emitting() {
        use futures::StreamExt;

        let schema = test_schema();
        let batch1 = make_batch(&["a"], &[1]);
        let batch2 = make_batch(&["b"], &[2]);

        let metrics = ExecutionPlanMetricsSet::new();
        let baseline = BaselineMetrics::new(&metrics, 0);

        // Construct the stream already in the Emitting state (buffered batches,
        // cursor at 0) so only the drain phase is exercised — the fixpoint
        // computation itself is covered by the FixpointState tests above.
        let mut stream = FixpointStream {
            state: FixpointStreamState::Emitting(vec![batch1, batch2], 0),
            schema,
            metrics: baseline,
        };

        let stream = Pin::new(&mut stream);
        // Collect every Ok item; errors are silently dropped by filter_map, so
        // a stream error would surface as a length mismatch in the assert below.
        let batches: Vec<RecordBatch> = stream.filter_map(|r| async { r.ok() }).collect().await;

        // Both buffered one-row batches are emitted, in order.
        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].num_rows(), 1);
        assert_eq!(batches[1].num_rows(), 1);
1909    }
1910}