Skip to main content

uni_query/query/df_graph/
locy_fold.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! FOLD operator for Locy.
5//!
6//! `FoldExec` applies fold (lattice-join) semantics: for each group of rows sharing
7//! the same KEY columns, it reduces non-key columns via their declared fold functions.
8
9use crate::query::df_graph::common::{
10    ScalarKey, arrow_err, compute_plan_properties, extract_scalar_key,
11};
12use arrow_array::builder::Float64Builder;
13use arrow_array::{Array, RecordBatch};
14use arrow_schema::{DataType, Field, Schema, SchemaRef};
15use datafusion::common::Result as DFResult;
16use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
17use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
18use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
19use datafusion::scalar::ScalarValue;
20use futures::{Stream, TryStreamExt};
21use smol_str::SmolStr;
22use std::any::Any;
23use std::collections::HashMap;
24use std::fmt;
25use std::pin::Pin;
26use std::sync::{Arc, OnceLock};
27use std::task::{Context, Poll};
28use uni_locy::SemiringKind;
29use uni_plugin::traits::locy::{FoldContext, FoldSemiring, LocyAggregate};
30
31use super::locy_explain::ProofTerm;
32
33/// Plugin-aware resolution of an aggregate name to a [`uni_plugin::traits::locy::LocyAggregate`].
34///
35/// Looks up `name` (case-folded) against the supplied [`uni_plugin::PluginRegistry`]
36/// under the reserved built-in namespace. Returns `None` if no plugin claims
37/// the aggregate.
38///
39/// Accepts legacy grammar aliases: the bare (`SUM`/`MAX`/`MIN`/`COUNT`) and
40/// `M`-prefixed (`MSUM`/`MMAX`/`MMIN`/`MCOUNT`) forms, plus `NOR`→`MNOR`,
41/// `PROD`→`MPROD`. `COUNTALL` (the zero-argument `COUNT()`/`MCOUNT()` form)
42/// resolves to its own dedicated aggregate, distinct from null-skipping
43/// `COUNT`.
44///
45/// # Examples
46///
47/// ```ignore
48/// use uni_query::query::df_graph::locy_fold::{default_locy_plugin_registry, resolve_locy_aggregate};
49/// let r = default_locy_plugin_registry();
50/// let agg = resolve_locy_aggregate(&r, "SUM");
51/// assert!(agg.is_some());
52/// ```
53/// Returns the monotonicity verdict for an aggregate name resolved through
54/// the supplied [`uni_plugin::PluginRegistry`].
55///
56/// `Some(true)` — registered monotone aggregate (`Semilattice.monotone_join`
57/// is `true`), sound in recursive Locy strata. `Some(false)` — registered
58/// but non-monotone, must be rejected in recursion. `None` — unregistered.
59///
60/// Aliases (`MSUM`/`MMAX`/`MMIN`/`MCOUNT`/`NOR`/`PROD`/`COUNTALL`) are
61/// canonicalized by [`resolve_locy_aggregate`] before lookup.
62#[must_use]
63pub fn is_monotonic_aggregate(registry: &uni_plugin::PluginRegistry, name: &str) -> Option<bool> {
64    resolve_locy_aggregate(registry, name).map(|e| e.aggregate.semilattice().monotone_join)
65}
66
67#[must_use]
68pub fn resolve_locy_aggregate(
69    registry: &uni_plugin::PluginRegistry,
70    name: &str,
71) -> Option<std::sync::Arc<uni_plugin::registry::LocyAggregateEntry>> {
72    let canonical = match name.to_uppercase().as_str() {
73        "MMAX" => "MAX".to_owned(),
74        "MMIN" => "MIN".to_owned(),
75        "MCOUNT" => "COUNT".to_owned(),
76        "NOR" => "MNOR".to_owned(),
77        "PROD" => "MPROD".to_owned(),
78        other => other.to_owned(),
79    };
80    let qname = uni_plugin::QName::builtin(canonical);
81    // M8.6 dual-consult: session-local first (if a Session has set the
82    // task-local via `scoped_with_session_plugin_registry`), then fall
83    // back to the caller-supplied (instance) registry. This makes
84    // session-scoped Locy aggregates visible without changing any
85    // caller of `resolve_locy_aggregate`.
86    if let Some(session_pr) = crate::current_session_plugin_registry()
87        && let Some(entry) = session_pr.locy_aggregate(&qname)
88    {
89        return Some(entry);
90    }
91    registry.locy_aggregate(&qname)
92}
93
94/// Returns a process-wide [`uni_plugin::PluginRegistry`] pre-populated with
95/// the built-in Locy aggregates from `uni-plugin-builtin`.
96///
97/// Used by [`crate::query::df_planner::HybridPhysicalPlanner`] as a default
98/// when the host has not supplied its own registry. Lazily initialized
99/// at first call and shared thereafter.
100///
101/// # Panics
102///
103/// Panics only on framework-internal invariants: capability gating, qname
104/// validation, or duplicate commit. The built-in registration set is fixed
105/// and cannot trigger any of these at runtime.
106#[must_use]
107pub fn default_locy_plugin_registry() -> Arc<uni_plugin::PluginRegistry> {
108    static REGISTRY: OnceLock<Arc<uni_plugin::PluginRegistry>> = OnceLock::new();
109    Arc::clone(REGISTRY.get_or_init(|| {
110        let registry = uni_plugin::PluginRegistry::new();
111        let plugin_id = uni_plugin::PluginId::new(uni_plugin::QName::BUILTIN_NS);
112        let caps = uni_plugin::CapabilitySet::from_iter_of([uni_plugin::Capability::LocyAggregate]);
113        let mut r = uni_plugin::PluginRegistrar::new(plugin_id, &caps, &registry);
114        uni_plugin_builtin::locy_aggregates::register_into(&mut r)
115            .expect("built-in locy aggregates register");
116        r.commit_to_registry().expect("commit built-in aggregates");
117        Arc::new(registry)
118    }))
119}
120
121/// A single FOLD binding: aggregate an input column into an output column.
122///
123/// Carries the canonical aggregate name (used as a sentinel for `COUNTALL`
124/// and for batch-path dispatch in [`FoldExec`]) alongside the resolved
125/// [`LocyAggregate`] trait object (used by the fixpoint runtime). The name
126/// is one of: `SUM`, `MIN`, `MAX`, `COUNT`, `COUNTALL`, `AVG`, `COLLECT`,
127/// `MNOR`, `MPROD`.
128#[derive(Debug, Clone)]
129pub struct FoldBinding {
130    pub output_name: String,
131    /// Canonical uppercase aggregate name.
132    pub name: SmolStr,
133    /// Resolved aggregate trait object (registry-backed).
134    pub aggregate: Arc<dyn LocyAggregate>,
135    pub input_col_index: usize,
136    /// Column name for name-based resolution (more robust than positional index).
137    /// `None` for `COUNTALL` which has no input column.
138    pub input_col_name: Option<String>,
139}
140
141/// DataFusion `ExecutionPlan` that applies FOLD semantics.
142///
143/// Groups rows by KEY columns and computes aggregates (SUM, MAX, MIN, COUNT, AVG, COLLECT)
144/// for each fold binding. Output schema is KEY columns + fold output columns.
145#[derive(Debug)]
146pub struct FoldExec {
147    input: Arc<dyn ExecutionPlan>,
148    key_indices: Vec<usize>,
149    fold_bindings: Vec<FoldBinding>,
150    strict_probability_domain: bool,
151    probability_epsilon: f64,
152    /// Active probability semiring. `AddMultProb` (the default) preserves
153    /// byte-identical Phase 1/2 noisy-OR / product behavior. `MaxMinProb`
154    /// (Viterbi) is opt-in and produces fuzzy-truth values; callers up the
155    /// stack also emit `FuzzyNotProbabilistic` on PROB-bearing rules.
156    semiring_kind: SemiringKind,
157    /// Phase D D-C0: under `SemiringKind::TopKProofs`, MNOR aggregates use
158    /// DNF inclusion-exclusion over the row's support chain (lifted from
159    /// the provenance tracker) rather than independence-mode noisy-OR.
160    /// `None` for non-TopK semirings — keeps the byte-identical `f64`
161    /// path for AddMultProb / MaxMinProb.
162    provenance_tracker: Option<Arc<super::locy_explain::ProvenanceStore>>,
163    /// Phase D D-C0: top-k retention used for proof pruning. Mirrors the
164    /// fixpoint-loop config; passed through so per-group `TopKTag`
165    /// merges respect the same K as the in-loop accumulator.
166    top_k_proofs_k: usize,
167    /// Pre-computed map of body-row content hash → IS-ref support
168    /// (`Vec<ProofTerm>`) for use by `topk_dnf_disjunction`. Populated
169    /// in `apply_post_fixpoint_chain` *before* this `FoldExec` is
170    /// built, because at FOLD time the current rule's own facts are
171    /// not yet recorded in the `ProvenanceStore` (which is keyed by
172    /// post-YIELD hashes anyway). `None` for non-TopK semirings and
173    /// for legacy callers.
174    body_support_map: Option<Arc<HashMap<Vec<u8>, Vec<ProofTerm>>>>,
175    schema: SchemaRef,
176    properties: Arc<PlanProperties>,
177    metrics: ExecutionPlanMetricsSet,
178}
179
180impl FoldExec {
181    /// Create a new `FoldExec`.
182    ///
183    /// # Arguments
184    /// * `input` - Child execution plan
185    /// * `key_indices` - Indices of KEY columns for grouping
186    /// * `fold_bindings` - Aggregate bindings (output name, kind, input col index)
187    pub fn new(
188        input: Arc<dyn ExecutionPlan>,
189        key_indices: Vec<usize>,
190        fold_bindings: Vec<FoldBinding>,
191        strict_probability_domain: bool,
192        probability_epsilon: f64,
193    ) -> Self {
194        Self::new_with_semiring(
195            input,
196            key_indices,
197            fold_bindings,
198            strict_probability_domain,
199            probability_epsilon,
200            SemiringKind::AddMultProb,
201        )
202    }
203
204    /// Variant taking an explicit [`SemiringKind`]. Existing callers can
205    /// keep using [`FoldExec::new`] (which defaults to `AddMultProb`); the
206    /// fixpoint planner uses this form to thread the configured semiring
207    /// from [`uni_locy::LocyConfig::resolve`].
208    pub fn new_with_semiring(
209        input: Arc<dyn ExecutionPlan>,
210        key_indices: Vec<usize>,
211        fold_bindings: Vec<FoldBinding>,
212        strict_probability_domain: bool,
213        probability_epsilon: f64,
214        semiring_kind: SemiringKind,
215    ) -> Self {
216        Self::new_with_topk(
217            input,
218            key_indices,
219            fold_bindings,
220            strict_probability_domain,
221            probability_epsilon,
222            semiring_kind,
223            None,
224            0,
225            None,
226        )
227    }
228
229    /// Phase D D-C0: variant that threads the provenance tracker and
230    /// `top_k_proofs` config so MNOR under `SemiringKind::TopKProofs`
231    /// can resolve each row's IS-ref support chain into a `Proof` and
232    /// aggregate via DNF inclusion-exclusion.
233    ///
234    /// `body_support_map` (Phase D D-C0 follow-up) is a pre-computed
235    /// `body_row_hash → Vec<ProofTerm>` map, populated by
236    /// `apply_post_fixpoint_chain` for TopKProofs rules. The tracker
237    /// alone is insufficient — its entries are keyed by post-YIELD
238    /// row hashes and are only populated after FOLD runs; the pre-fold
239    /// body rows seen here would never hit. The map closes that gap.
240    #[allow(clippy::too_many_arguments)]
241    pub fn new_with_topk(
242        input: Arc<dyn ExecutionPlan>,
243        key_indices: Vec<usize>,
244        fold_bindings: Vec<FoldBinding>,
245        strict_probability_domain: bool,
246        probability_epsilon: f64,
247        semiring_kind: SemiringKind,
248        provenance_tracker: Option<Arc<super::locy_explain::ProvenanceStore>>,
249        top_k_proofs_k: usize,
250        body_support_map: Option<Arc<HashMap<Vec<u8>, Vec<ProofTerm>>>>,
251    ) -> Self {
252        let input_schema = input.schema();
253        let schema = Self::build_output_schema(&input_schema, &key_indices, &fold_bindings);
254        let properties = compute_plan_properties(Arc::clone(&schema));
255
256        Self {
257            input,
258            key_indices,
259            fold_bindings,
260            strict_probability_domain,
261            probability_epsilon,
262            semiring_kind,
263            provenance_tracker,
264            top_k_proofs_k,
265            body_support_map,
266            schema,
267            properties,
268            metrics: ExecutionPlanMetricsSet::new(),
269        }
270    }
271
272    fn build_output_schema(
273        input_schema: &SchemaRef,
274        key_indices: &[usize],
275        fold_bindings: &[FoldBinding],
276    ) -> SchemaRef {
277        let mut fields = Vec::new();
278
279        // Key columns preserve original types
280        for &ki in key_indices {
281            fields.push(Arc::new(input_schema.field(ki).clone()));
282        }
283
284        // Fold output columns — type derived from the aggregate trait. The
285        // input column type is resolved first (name-then-index) so
286        // type-preserving aggregates (`MIN`/`MAX`) can return it.
287        for binding in fold_bindings {
288            let idx = binding
289                .input_col_name
290                .as_ref()
291                .and_then(|name| input_schema.index_of(name).ok())
292                .unwrap_or(binding.input_col_index);
293            let input_type = if idx < input_schema.fields().len() {
294                input_schema.field(idx).data_type().clone()
295            } else {
296                DataType::Float64
297            };
298            let output_type = binding.aggregate.output_type_for_input(&input_type);
299            fields.push(Arc::new(Field::new(
300                &binding.output_name,
301                output_type,
302                true,
303            )));
304        }
305
306        Arc::new(Schema::new(fields))
307    }
308}
309
310impl DisplayAs for FoldExec {
311    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
312        write!(
313            f,
314            "FoldExec: key_indices={:?}, bindings={:?}",
315            self.key_indices, self.fold_bindings
316        )
317    }
318}
319
320impl ExecutionPlan for FoldExec {
321    fn name(&self) -> &str {
322        "FoldExec"
323    }
324
325    fn as_any(&self) -> &dyn Any {
326        self
327    }
328
329    fn schema(&self) -> SchemaRef {
330        Arc::clone(&self.schema)
331    }
332
333    fn properties(&self) -> &Arc<PlanProperties> {
334        &self.properties
335    }
336
337    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
338        vec![&self.input]
339    }
340
341    fn with_new_children(
342        self: Arc<Self>,
343        children: Vec<Arc<dyn ExecutionPlan>>,
344    ) -> DFResult<Arc<dyn ExecutionPlan>> {
345        if children.len() != 1 {
346            return Err(datafusion::error::DataFusionError::Plan(
347                "FoldExec requires exactly one child".to_string(),
348            ));
349        }
350        Ok(Arc::new(Self::new_with_topk(
351            Arc::clone(&children[0]),
352            self.key_indices.clone(),
353            self.fold_bindings.clone(),
354            self.strict_probability_domain,
355            self.probability_epsilon,
356            self.semiring_kind,
357            self.provenance_tracker.as_ref().map(Arc::clone),
358            self.top_k_proofs_k,
359            self.body_support_map.as_ref().map(Arc::clone),
360        )))
361    }
362
363    fn execute(
364        &self,
365        partition: usize,
366        context: Arc<TaskContext>,
367    ) -> DFResult<SendableRecordBatchStream> {
368        let input_stream = self.input.execute(partition, Arc::clone(&context))?;
369        let metrics = BaselineMetrics::new(&self.metrics, partition);
370        let key_indices = self.key_indices.clone();
371        let fold_bindings = self.fold_bindings.clone();
372        let strict = self.strict_probability_domain;
373        let epsilon = self.probability_epsilon;
374        let semiring_kind = self.semiring_kind;
375        let _provenance_tracker = self.provenance_tracker.as_ref().map(Arc::clone);
376        let top_k_proofs_k = self.top_k_proofs_k;
377        let body_support_map = self.body_support_map.as_ref().map(Arc::clone);
378        let output_schema = Arc::clone(&self.schema);
379        let input_schema = self.input.schema();
380
381        let fut = async move {
382            let batches: Vec<RecordBatch> = input_stream.try_collect().await?;
383
384            if batches.is_empty() {
385                return Ok(RecordBatch::new_empty(output_schema));
386            }
387
388            // Use the actual batch schema (may differ from pre-computed input_schema
389            // after schema reconciliation in schemaless mode).
390            let actual_schema = batches
391                .first()
392                .map(|b| b.schema())
393                .unwrap_or(input_schema.clone());
394            let batch =
395                arrow::compute::concat_batches(&actual_schema, &batches).map_err(arrow_err)?;
396
397            if batch.num_rows() == 0 {
398                return Ok(RecordBatch::new_empty(output_schema));
399            }
400
401            // Group by key columns → row indices, preserving insertion order
402            let mut groups: HashMap<Vec<ScalarKey>, Vec<usize>> = HashMap::new();
403            let mut ordered_keys: Vec<Vec<ScalarKey>> = Vec::new();
404            for row_idx in 0..batch.num_rows() {
405                let key = extract_scalar_key(&batch, &key_indices, row_idx);
406                let entry = groups.entry(key.clone());
407                if matches!(entry, std::collections::hash_map::Entry::Vacant(_)) {
408                    ordered_keys.push(key);
409                }
410                entry.or_default().push(row_idx);
411            }
412
413            let num_groups = ordered_keys.len();
414
415            // Build output columns
416            let mut output_columns: Vec<arrow_array::ArrayRef> = Vec::new();
417
418            // Key columns: take from first row of each group
419            for &ki in &key_indices {
420                if ki >= batch.num_columns() {
421                    continue; // Skip invalid indices after schema reconciliation
422                }
423                let col = batch.column(ki);
424                let first_indices: Vec<u32> =
425                    ordered_keys.iter().map(|k| groups[k][0] as u32).collect();
426                let idx_array = arrow_array::UInt32Array::from(first_indices);
427                let taken = arrow::compute::take(col.as_ref(), &idx_array, None).map_err(|e| {
428                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
429                })?;
430                output_columns.push(taken);
431            }
432
433            // Per-fold evaluation context. `TopKProofs` / `BddExact` are
434            // provenance specializations handled above the aggregate (see
435            // `compute_fold_aggregate`); the aggregate trait only sees the
436            // two value-level combinators.
437            let cx = FoldContext {
438                strict,
439                epsilon,
440                semiring: match semiring_kind {
441                    SemiringKind::MaxMinProb => FoldSemiring::MaxMin,
442                    _ => FoldSemiring::AddMult,
443                },
444            };
445
446            // Fold binding columns: compute aggregates per group via the
447            // resolved `LocyAggregate` trait object.
448            for binding in &fold_bindings {
449                let col: Arc<dyn Array> = if binding.name.as_str() == "COUNTALL" {
450                    // COUNTALL has no input column — the aggregate ignores it.
451                    Arc::new(arrow_array::Int64Array::from(vec![0i64; batch.num_rows()]))
452                } else {
453                    // Resolve input column: prefer name-based lookup, fall back to index.
454                    let resolved_idx = binding
455                        .input_col_name
456                        .as_ref()
457                        .and_then(|name| batch.schema().index_of(name).ok())
458                        .unwrap_or(binding.input_col_index);
459                    if resolved_idx < batch.num_columns() {
460                        Arc::clone(batch.column(resolved_idx))
461                    } else {
462                        // Column not found — use zeros as fallback
463                        Arc::new(arrow_array::Float64Array::from(vec![
464                            0.0f64;
465                            batch.num_rows()
466                        ]))
467                    }
468                };
469                let topk_ctx = if matches!(semiring_kind, SemiringKind::TopKProofs { .. }) {
470                    Some(TopKFoldCtx {
471                        k: top_k_proofs_k,
472                        batch: &batch,
473                        body_support_map: body_support_map.as_deref(),
474                    })
475                } else {
476                    None
477                };
478                let agg_col = compute_fold_aggregate(
479                    col.as_ref(),
480                    &binding.aggregate,
481                    FoldGroups {
482                        ordered_keys: &ordered_keys,
483                        groups: &groups,
484                        num_groups,
485                    },
486                    &cx,
487                    topk_ctx.as_ref(),
488                )?;
489                output_columns.push(agg_col);
490            }
491
492            RecordBatch::try_new(output_schema, output_columns).map_err(arrow_err)
493        };
494
495        Ok(Box::pin(FoldStream {
496            state: FoldStreamState::Running(Box::pin(fut)),
497            schema: Arc::clone(&self.schema),
498            metrics,
499        }))
500    }
501
502    fn metrics(&self) -> Option<MetricsSet> {
503        Some(self.metrics.clone_inner())
504    }
505}
506
507// ---------------------------------------------------------------------------
508// Aggregate computation
509// ---------------------------------------------------------------------------
510
511/// Per-key-group data threaded into `compute_fold_aggregate` —
512/// bundled to keep the aggregator's signature under the
513/// too-many-arguments threshold.
514struct FoldGroups<'a> {
515    ordered_keys: &'a [Vec<ScalarKey>],
516    groups: &'a HashMap<Vec<ScalarKey>, Vec<usize>>,
517    num_groups: usize,
518}
519
520/// Phase D D-C0: per-call context for TopKProofs-aware aggregation.
521/// Carries the K config and a pre-computed body-row → IS-ref support
522/// map so MNOR / MPROD can build per-row `Proof`s and aggregate via
523/// DNF inclusion-exclusion. The map is built in
524/// `apply_post_fixpoint_chain` before `FoldExec` is constructed; the
525/// provenance tracker is *not* sufficient at this stage because its
526/// entries are keyed by post-YIELD row hashes and the current rule's
527/// facts have not been recorded yet.
528struct TopKFoldCtx<'a> {
529    k: usize,
530    batch: &'a RecordBatch,
531    body_support_map: Option<&'a HashMap<Vec<u8>, Vec<ProofTerm>>>,
532}
533
534/// Compute one fold-output column by dispatching through the resolved
535/// [`LocyAggregate`] trait object.
536///
537/// For each key group a fresh [`LocyAggState`](uni_plugin::traits::locy::LocyAggState)
538/// is created, fed the group's rows via `ingest_indices`, and finalized; the
539/// per-group [`ScalarValue`]s are assembled into the output array (whose type
540/// follows the finalized scalars — `Int64`/`Float64` for `MIN`/`MAX` over those
541/// inputs, `LargeBinary` for `COLLECT`, etc.).
542///
543/// `TopKProofs` noisy-OR is the one provenance specialization that sits
544/// *above* the aggregate: when a TopK context is threaded and the aggregate is
545/// noisy-OR, each group folds via DNF inclusion-exclusion over its support
546/// chains instead (degrading to plain noisy-OR when no support map exists).
547///
548/// # Errors
549///
550/// Returns a [`DFResult`] error if a plugin aggregate rejects a value (e.g., a
551/// strict probability-domain violation) or if the per-group scalars cannot be
552/// assembled into an Arrow array.
553fn compute_fold_aggregate(
554    col: &dyn Array,
555    aggregate: &Arc<dyn LocyAggregate>,
556    groups_ctx: FoldGroups<'_>,
557    cx: &FoldContext,
558    topk_ctx: Option<&TopKFoldCtx<'_>>,
559) -> DFResult<arrow_array::ArrayRef> {
560    let ordered_keys = groups_ctx.ordered_keys;
561    let groups = groups_ctx.groups;
562    let num_groups = groups_ctx.num_groups;
563
564    // Phase D D-C0: TopKProofs noisy-OR uses DNF inclusion-exclusion over each
565    // row's support chain. This is a provenance specialization layered above
566    // the MNOR aggregate; it falls back to independence-mode noisy-OR when no
567    // support map is present (the common / non-recursive case).
568    if let Some(ctx) = topk_ctx
569        && aggregate.is_noisy_or()
570    {
571        let mut builder = Float64Builder::with_capacity(num_groups);
572        for key in ordered_keys {
573            builder.append_option(topk_dnf_disjunction(col, &groups[key], cx.strict, ctx)?);
574        }
575        return Ok(Arc::new(builder.finish()));
576    }
577
578    // Generic trait dispatch: one aggregate state per key group.
579    let mut scalars: Vec<ScalarValue> = Vec::with_capacity(num_groups);
580    for key in ordered_keys {
581        let mut state = aggregate.create();
582        state
583            .ingest_indices(col, &groups[key], cx)
584            .map_err(fn_error_to_df)?;
585        scalars.push(state.finalize().map_err(fn_error_to_df)?);
586    }
587    ScalarValue::iter_to_array(scalars)
588}
589
590/// Map a plugin [`FnError`](uni_plugin::FnError) to a DataFusion error,
591/// preserving the message so strict-domain text survives to the caller.
592fn fn_error_to_df(e: uni_plugin::FnError) -> datafusion::error::DataFusionError {
593    datafusion::error::DataFusionError::Execution(e.message)
594}
595
596/// Phase D D-C0: TopKProofs MNOR via DNF inclusion-exclusion over the
597/// rows' support chains. Each row contributes one `Proof` whose
598/// weight is the row's MNOR-input value and whose `base_rvs` are
599/// interned from the row's IS-ref support, resolved via
600/// `ctx.body_support_map` — a precomputed map of body-row content
601/// hash → `Vec<ProofTerm>`. The map is built in
602/// `apply_post_fixpoint_chain` before FOLD; the `ProvenanceStore`
603/// tracker cannot be used here because its entries are keyed by
604/// post-YIELD hashes and the rule's own facts haven't been recorded
605/// yet at FOLD time. Proofs are merged via `merge_top_k_runtime` (so
606/// the K config is respected and `CrossedDependency` notices ride
607/// the side-channel). The per-group output is
608/// `TopKTag.to_dnf().weight(&base_weights)` — exact when no
609/// dependency overlap exists, exact under inclusion-exclusion when
610/// shared base facts appear across proofs.
611///
612/// Rows whose body-hash isn't in the support map (e.g. rules whose
613/// MNOR runs over plain columns with no IS-ref bindings) contribute
614/// empty-support Proofs — the math degrades to plain f64 noisy-OR
615/// (independence-mode), preserving the pre-D-C0 byte-identical
616/// AddMultProb behavior.
617fn topk_dnf_disjunction(
618    col: &dyn Array,
619    indices: &[usize],
620    strict: bool,
621    ctx: &TopKFoldCtx<'_>,
622) -> DFResult<Option<f64>> {
623    use uni_locy::{BaseRv, BaseRvSet, Proof};
624
625    let batch = ctx.batch;
626    let all_indices: Vec<usize> = (0..batch.num_columns()).collect();
627    let mut interner: HashMap<Vec<u8>, BaseRv> = HashMap::new();
628    let mut next_rv: u32 = 0;
629    let mut base_weights: HashMap<BaseRv, f64> = HashMap::new();
630    let mut proofs: Vec<Proof> = Vec::with_capacity(indices.len());
631
632    for &i in indices {
633        if col.is_null(i) {
634            continue;
635        }
636        // Row's MNOR-input value (e.g. an IS-ref edge probability).
637        let val = match col.as_any().downcast_ref::<arrow_array::Float64Array>() {
638            Some(arr) => arr.value(i),
639            None => match col.as_any().downcast_ref::<arrow_array::Int64Array>() {
640                Some(arr) => arr.value(i) as f64,
641                None => continue,
642            },
643        };
644        if strict && !(0.0..=1.0).contains(&val) {
645            return Err(datafusion::error::DataFusionError::Execution(format!(
646                "strict_probability_domain: MNOR input {val} outside [0,1] under TopKProofs"
647            )));
648        }
649        let weight = val.clamp(0.0, 1.0);
650
651        // Resolve row's IS-ref support via the precomputed body-row map;
652        // intern base facts into BaseRvs.
653        let fact_hash = super::locy_fixpoint::fact_hash_key(batch, &all_indices, i);
654        let mut base_rvs = BaseRvSet::empty();
655        if let Some(support) = ctx.body_support_map.and_then(|m| m.get(&fact_hash)) {
656            for term in support {
657                let rv = *interner
658                    .entry(term.base_fact_id.clone())
659                    .or_insert_with(|| {
660                        let r = BaseRv(next_rv);
661                        next_rv += 1;
662                        r
663                    });
664                base_rvs.insert(rv);
665            }
666        }
667        // Single-row Proof. Base weights for the DNF: assign the row's
668        // weight to each base RV under it (when no support exists,
669        // base_rvs is empty and the proof's weight stands alone).
670        // When multiple rows share the same base RV (shared-proof case),
671        // take max — deterministic regardless of row visit order, and
672        // a conservative upper bound for the noisy-OR DNF.
673        if base_rvs.iter().count() > 0 {
674            for rv in base_rvs.iter() {
675                base_weights
676                    .entry(rv)
677                    .and_modify(|w| {
678                        if weight > *w {
679                            *w = weight;
680                        }
681                    })
682                    .or_insert(weight);
683            }
684        }
685        proofs.push(Proof {
686            weight,
687            base_rvs,
688            neural_calls: Vec::new(),
689        });
690    }
691    if proofs.is_empty() {
692        return Ok(None);
693    }
694    // When NO proof carries base_rvs (no IS-ref support visible —
695    // the rule's MNOR runs over plain columns, not derived facts),
696    // fall back to independence-mode noisy-OR. Going through
697    // `merge_top_k` here is wrong: it dedupes by dependency_key,
698    // collapsing all empty-base_rvs proofs into one max-weight
699    // proof. Plain noisy-OR over each row's weight preserves the
700    // pre-D-C0 AddMultProb behavior byte-identically.
701    if base_weights.is_empty() {
702        let mut complement = 1.0;
703        for p in &proofs {
704            complement *= 1.0 - p.weight;
705        }
706        return Ok(Some((1.0 - complement).clamp(0.0, 1.0)));
707    }
708    // At least one proof carries base_rvs — DNF inclusion-exclusion
709    // is meaningful. Merge top-K (which dedupes by dependency_key
710    // intentionally — shared bases ARE the same dependency) and
711    // compute exact (or top-K-approximated) probability via the
712    // DNF.
713    let k = if ctx.k == 0 { proofs.len() } else { ctx.k };
714    let (kept, _notice) = uni_locy::merge_top_k_runtime(Vec::new(), proofs, k);
715    let tag = uni_locy::TopKTag { proofs: kept };
716    Ok(Some(tag.to_dnf().weight(&base_weights)))
717}
718
719// ---------------------------------------------------------------------------
720// Stream implementation
721// ---------------------------------------------------------------------------
722
723enum FoldStreamState {
724    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
725    Done,
726}
727
/// Output stream of the fold operator: drives a single boxed future to
/// completion and yields its `RecordBatch` (or error) exactly once.
struct FoldStream {
    /// Pending fold future, or `Done` once it has produced its result.
    state: FoldStreamState,
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// DataFusion baseline metrics: `poll_next` times itself against
    /// `elapsed_compute` and records the emitted batch's row count.
    metrics: BaselineMetrics,
}
733
734impl Stream for FoldStream {
735    type Item = DFResult<RecordBatch>;
736
737    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
738        let metrics = self.metrics.clone();
739        let _timer = metrics.elapsed_compute().timer();
740        match &mut self.state {
741            FoldStreamState::Running(fut) => match fut.as_mut().poll(cx) {
742                Poll::Ready(Ok(batch)) => {
743                    self.metrics.record_output(batch.num_rows());
744                    self.state = FoldStreamState::Done;
745                    Poll::Ready(Some(Ok(batch)))
746                }
747                Poll::Ready(Err(e)) => {
748                    self.state = FoldStreamState::Done;
749                    Poll::Ready(Some(Err(e)))
750                }
751                Poll::Pending => Poll::Pending,
752            },
753            FoldStreamState::Done => Poll::Ready(None),
754        }
755    }
756}
757
758impl RecordBatchStream for FoldStream {
759    fn schema(&self) -> SchemaRef {
760        Arc::clone(&self.schema)
761    }
762}
763
764#[cfg(test)]
765mod tests {
766    use super::*;
767    use arrow_array::{Float64Array, Int64Array, StringArray};
768    use arrow_schema::{DataType, Field, Schema};
769    use datafusion::physical_plan::memory::MemoryStream;
770    use datafusion::prelude::SessionContext;
771
772    /// Direct construction of a built-in `LocyAggregate` trait object for use
773    /// in `FoldBinding` test fixtures. Avoids registry plumbing in tests that
774    /// only need a working aggregate.
775    fn builtin_agg(name: &str) -> Arc<dyn LocyAggregate> {
776        use uni_plugin_builtin::locy_aggregates::*;
777        match name {
778            "SUM" | "MSUM" => Arc::new(SumAgg),
779            "MAX" | "MMAX" => Arc::new(MaxAgg),
780            "MIN" | "MMIN" => Arc::new(MinAgg),
781            "COUNT" | "COUNTALL" | "MCOUNT" => Arc::new(CountAgg),
782            "AVG" => Arc::new(AvgAgg),
783            "COLLECT" => Arc::new(CollectAgg),
784            "MNOR" | "NOR" => Arc::new(MnorAgg),
785            "MPROD" | "PROD" => Arc::new(MprodAgg),
786            other => panic!("unknown test aggregate `{other}`"),
787        }
788    }
789
790    fn make_test_batch(names: Vec<&str>, values: Vec<f64>) -> RecordBatch {
791        let schema = Arc::new(Schema::new(vec![
792            Field::new("name", DataType::Utf8, true),
793            Field::new("value", DataType::Float64, true),
794        ]));
795        RecordBatch::try_new(
796            schema,
797            vec![
798                Arc::new(StringArray::from(
799                    names.into_iter().map(Some).collect::<Vec<_>>(),
800                )),
801                Arc::new(Float64Array::from(values)),
802            ],
803        )
804        .unwrap()
805    }
806
807    fn make_memory_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
808        let schema = batch.schema();
809        Arc::new(TestMemoryExec {
810            batches: vec![batch],
811            schema: schema.clone(),
812            properties: compute_plan_properties(schema),
813        })
814    }
815
    /// Minimal leaf execution plan that replays a fixed list of in-memory
    /// batches; used as the input to `FoldExec` in these tests.
    #[derive(Debug)]
    struct TestMemoryExec {
        /// Batches handed out (cloned) on every `execute` call.
        batches: Vec<RecordBatch>,
        /// Schema shared by all of `batches`.
        schema: SchemaRef,
        /// Plan properties derived from `schema` via `compute_plan_properties`.
        properties: Arc<PlanProperties>,
    }
822
823    impl DisplayAs for TestMemoryExec {
824        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
825            write!(f, "TestMemoryExec")
826        }
827    }
828
829    impl ExecutionPlan for TestMemoryExec {
830        fn name(&self) -> &str {
831            "TestMemoryExec"
832        }
833        fn as_any(&self) -> &dyn Any {
834            self
835        }
836        fn schema(&self) -> SchemaRef {
837            Arc::clone(&self.schema)
838        }
839        fn properties(&self) -> &Arc<PlanProperties> {
840            &self.properties
841        }
842        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
843            vec![]
844        }
845        fn with_new_children(
846            self: Arc<Self>,
847            _children: Vec<Arc<dyn ExecutionPlan>>,
848        ) -> DFResult<Arc<dyn ExecutionPlan>> {
849            Ok(self)
850        }
851        fn execute(
852            &self,
853            _partition: usize,
854            _context: Arc<TaskContext>,
855        ) -> DFResult<SendableRecordBatchStream> {
856            Ok(Box::pin(MemoryStream::try_new(
857                self.batches.clone(),
858                Arc::clone(&self.schema),
859                None,
860            )?))
861        }
862    }
863
864    async fn execute_fold(
865        input: Arc<dyn ExecutionPlan>,
866        key_indices: Vec<usize>,
867        fold_bindings: Vec<FoldBinding>,
868    ) -> RecordBatch {
869        let exec = FoldExec::new(input, key_indices, fold_bindings, false, 1e-15);
870        let ctx = SessionContext::new();
871        let task_ctx = ctx.task_ctx();
872        let stream = exec.execute(0, task_ctx).unwrap();
873        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream)
874            .await
875            .unwrap();
876        if batches.is_empty() {
877            RecordBatch::new_empty(exec.schema())
878        } else {
879            arrow::compute::concat_batches(&exec.schema(), &batches).unwrap()
880        }
881    }
882
883    #[tokio::test]
884    async fn test_sum_single_group() {
885        let batch = make_test_batch(vec!["a", "a", "a"], vec![1.0, 2.0, 3.0]);
886        let input = make_memory_exec(batch);
887        let result = execute_fold(
888            input,
889            vec![0],
890            vec![FoldBinding {
891                output_name: "total".to_string(),
892                name: SmolStr::new_static("SUM"),
893                aggregate: builtin_agg("SUM"),
894                input_col_index: 1,
895                input_col_name: None,
896            }],
897        )
898        .await;
899
900        assert_eq!(result.num_rows(), 1);
901        let totals = result
902            .column(1)
903            .as_any()
904            .downcast_ref::<Float64Array>()
905            .unwrap();
906        assert!((totals.value(0) - 6.0).abs() < f64::EPSILON);
907    }
908
909    #[tokio::test]
910    async fn test_count_non_null() {
911        let schema = Arc::new(Schema::new(vec![
912            Field::new("name", DataType::Utf8, true),
913            Field::new("value", DataType::Float64, true),
914        ]));
915        let batch = RecordBatch::try_new(
916            schema,
917            vec![
918                Arc::new(StringArray::from(vec![Some("a"), Some("a"), Some("a")])),
919                Arc::new(Float64Array::from(vec![Some(1.0), None, Some(3.0)])),
920            ],
921        )
922        .unwrap();
923        let input = make_memory_exec(batch);
924        let result = execute_fold(
925            input,
926            vec![0],
927            vec![FoldBinding {
928                output_name: "cnt".to_string(),
929                name: SmolStr::new_static("COUNT"),
930                aggregate: builtin_agg("COUNT"),
931                input_col_index: 1,
932                input_col_name: None,
933            }],
934        )
935        .await;
936
937        assert_eq!(result.num_rows(), 1);
938        let counts = result
939            .column(1)
940            .as_any()
941            .downcast_ref::<Int64Array>()
942            .unwrap();
943        assert_eq!(counts.value(0), 2); // null not counted
944    }
945
946    #[tokio::test]
947    async fn test_max_min() {
948        let batch = make_test_batch(vec!["a", "a", "a"], vec![3.0, 1.0, 5.0]);
949        let input_max = make_memory_exec(batch.clone());
950        let input_min = make_memory_exec(batch);
951
952        let result_max = execute_fold(
953            input_max,
954            vec![0],
955            vec![FoldBinding {
956                output_name: "mx".to_string(),
957                name: SmolStr::new_static("MAX"),
958                aggregate: builtin_agg("MAX"),
959                input_col_index: 1,
960                input_col_name: None,
961            }],
962        )
963        .await;
964        let result_min = execute_fold(
965            input_min,
966            vec![0],
967            vec![FoldBinding {
968                output_name: "mn".to_string(),
969                name: SmolStr::new_static("MIN"),
970                aggregate: builtin_agg("MIN"),
971                input_col_index: 1,
972                input_col_name: None,
973            }],
974        )
975        .await;
976
977        let max_vals = result_max
978            .column(1)
979            .as_any()
980            .downcast_ref::<Float64Array>()
981            .unwrap();
982        assert_eq!(max_vals.value(0), 5.0);
983
984        let min_vals = result_min
985            .column(1)
986            .as_any()
987            .downcast_ref::<Float64Array>()
988            .unwrap();
989        assert_eq!(min_vals.value(0), 1.0);
990    }
991
992    #[tokio::test]
993    async fn test_avg() {
994        let batch = make_test_batch(vec!["a", "a", "a", "a"], vec![2.0, 4.0, 6.0, 8.0]);
995        let input = make_memory_exec(batch);
996        let result = execute_fold(
997            input,
998            vec![0],
999            vec![FoldBinding {
1000                output_name: "average".to_string(),
1001                name: SmolStr::new_static("AVG"),
1002                aggregate: builtin_agg("AVG"),
1003                input_col_index: 1,
1004                input_col_name: None,
1005            }],
1006        )
1007        .await;
1008
1009        assert_eq!(result.num_rows(), 1);
1010        let avgs = result
1011            .column(1)
1012            .as_any()
1013            .downcast_ref::<Float64Array>()
1014            .unwrap();
1015        assert!((avgs.value(0) - 5.0).abs() < f64::EPSILON);
1016    }
1017
1018    #[tokio::test]
1019    async fn test_multiple_groups() {
1020        let batch = make_test_batch(
1021            vec!["a", "a", "b", "b", "b"],
1022            vec![1.0, 2.0, 10.0, 20.0, 30.0],
1023        );
1024        let input = make_memory_exec(batch);
1025        let result = execute_fold(
1026            input,
1027            vec![0],
1028            vec![FoldBinding {
1029                output_name: "total".to_string(),
1030                name: SmolStr::new_static("SUM"),
1031                aggregate: builtin_agg("SUM"),
1032                input_col_index: 1,
1033                input_col_name: None,
1034            }],
1035        )
1036        .await;
1037
1038        assert_eq!(result.num_rows(), 2);
1039        let names = result
1040            .column(0)
1041            .as_any()
1042            .downcast_ref::<StringArray>()
1043            .unwrap();
1044        let totals = result
1045            .column(1)
1046            .as_any()
1047            .downcast_ref::<Float64Array>()
1048            .unwrap();
1049
1050        for i in 0..2 {
1051            match names.value(i) {
1052                "a" => assert!((totals.value(i) - 3.0).abs() < f64::EPSILON),
1053                "b" => assert!((totals.value(i) - 60.0).abs() < f64::EPSILON),
1054                _ => panic!("unexpected name"),
1055            }
1056        }
1057    }
1058
1059    #[tokio::test]
1060    async fn test_empty_input() {
1061        let schema = Arc::new(Schema::new(vec![
1062            Field::new("name", DataType::Utf8, true),
1063            Field::new("value", DataType::Float64, true),
1064        ]));
1065        let batch = RecordBatch::new_empty(schema);
1066        let input = make_memory_exec(batch);
1067        let result = execute_fold(
1068            input,
1069            vec![0],
1070            vec![FoldBinding {
1071                output_name: "total".to_string(),
1072                name: SmolStr::new_static("SUM"),
1073                aggregate: builtin_agg("SUM"),
1074                input_col_index: 1,
1075                input_col_name: None,
1076            }],
1077        )
1078        .await;
1079
1080        assert_eq!(result.num_rows(), 0);
1081    }
1082
1083    #[tokio::test]
1084    async fn test_multiple_bindings() {
1085        let batch = make_test_batch(vec!["a", "a", "a"], vec![1.0, 2.0, 3.0]);
1086        let input = make_memory_exec(batch);
1087        let result = execute_fold(
1088            input,
1089            vec![0],
1090            vec![
1091                FoldBinding {
1092                    output_name: "total".to_string(),
1093                    name: SmolStr::new_static("SUM"),
1094                    aggregate: builtin_agg("SUM"),
1095                    input_col_index: 1,
1096                    input_col_name: None,
1097                },
1098                FoldBinding {
1099                    output_name: "cnt".to_string(),
1100                    name: SmolStr::new_static("COUNT"),
1101                    aggregate: builtin_agg("COUNT"),
1102                    input_col_index: 1,
1103                    input_col_name: None,
1104                },
1105                FoldBinding {
1106                    output_name: "mx".to_string(),
1107                    name: SmolStr::new_static("MAX"),
1108                    aggregate: builtin_agg("MAX"),
1109                    input_col_index: 1,
1110                    input_col_name: None,
1111                },
1112            ],
1113        )
1114        .await;
1115
1116        assert_eq!(result.num_rows(), 1);
1117        assert_eq!(result.num_columns(), 4); // name + total + cnt + mx
1118
1119        let totals = result
1120            .column(1)
1121            .as_any()
1122            .downcast_ref::<Float64Array>()
1123            .unwrap();
1124        assert!((totals.value(0) - 6.0).abs() < f64::EPSILON);
1125
1126        let counts = result
1127            .column(2)
1128            .as_any()
1129            .downcast_ref::<Int64Array>()
1130            .unwrap();
1131        assert_eq!(counts.value(0), 3);
1132
1133        let maxes = result
1134            .column(3)
1135            .as_any()
1136            .downcast_ref::<Float64Array>()
1137            .unwrap();
1138        assert_eq!(maxes.value(0), 3.0);
1139    }
1140
1141    // ── MNOR tests ───────────────────────────────────────────────────────
1142
1143    #[tokio::test]
1144    async fn test_nor_single_group() {
1145        // MNOR({0.3, 0.5}) = 1 - (1-0.3)*(1-0.5) = 1 - 0.7*0.5 = 1 - 0.35 = 0.65
1146        let batch = make_test_batch(vec!["a", "a"], vec![0.3, 0.5]);
1147        let input = make_memory_exec(batch);
1148        let result = execute_fold(
1149            input,
1150            vec![0],
1151            vec![FoldBinding {
1152                output_name: "prob".to_string(),
1153                name: SmolStr::new_static("MNOR"),
1154                aggregate: builtin_agg("MNOR"),
1155                input_col_index: 1,
1156                input_col_name: None,
1157            }],
1158        )
1159        .await;
1160
1161        assert_eq!(result.num_rows(), 1);
1162        let vals = result
1163            .column(1)
1164            .as_any()
1165            .downcast_ref::<Float64Array>()
1166            .unwrap();
1167        assert!((vals.value(0) - 0.65).abs() < 1e-10);
1168    }
1169
1170    #[tokio::test]
1171    async fn test_nor_identity() {
1172        // MNOR({0.0, 0.0}) = 1 - (1-0)*(1-0) = 1 - 1 = 0.0
1173        let batch = make_test_batch(vec!["a", "a"], vec![0.0, 0.0]);
1174        let input = make_memory_exec(batch);
1175        let result = execute_fold(
1176            input,
1177            vec![0],
1178            vec![FoldBinding {
1179                output_name: "prob".to_string(),
1180                name: SmolStr::new_static("MNOR"),
1181                aggregate: builtin_agg("MNOR"),
1182                input_col_index: 1,
1183                input_col_name: None,
1184            }],
1185        )
1186        .await;
1187
1188        let vals = result
1189            .column(1)
1190            .as_any()
1191            .downcast_ref::<Float64Array>()
1192            .unwrap();
1193        assert!((vals.value(0) - 0.0).abs() < 1e-10);
1194    }
1195
1196    #[tokio::test]
1197    async fn test_nor_clamping() {
1198        // Out-of-range values should be clamped to [0, 1]
1199        let batch = make_test_batch(vec!["a", "a"], vec![-0.5, 1.5]);
1200        let input = make_memory_exec(batch);
1201        let result = execute_fold(
1202            input,
1203            vec![0],
1204            vec![FoldBinding {
1205                output_name: "prob".to_string(),
1206                name: SmolStr::new_static("MNOR"),
1207                aggregate: builtin_agg("MNOR"),
1208                input_col_index: 1,
1209                input_col_name: None,
1210            }],
1211        )
1212        .await;
1213
1214        let vals = result
1215            .column(1)
1216            .as_any()
1217            .downcast_ref::<Float64Array>()
1218            .unwrap();
1219        // Clamped to (0.0, 1.0): MNOR = 1 - (1-0)*(1-1) = 1 - 1*0 = 1.0
1220        assert!((vals.value(0) - 1.0).abs() < 1e-10);
1221    }
1222
1223    #[tokio::test]
1224    async fn test_nor_multiple_groups() {
1225        let batch = make_test_batch(vec!["a", "a", "b", "b"], vec![0.3, 0.5, 0.1, 0.2]);
1226        let input = make_memory_exec(batch);
1227        let result = execute_fold(
1228            input,
1229            vec![0],
1230            vec![FoldBinding {
1231                output_name: "prob".to_string(),
1232                name: SmolStr::new_static("MNOR"),
1233                aggregate: builtin_agg("MNOR"),
1234                input_col_index: 1,
1235                input_col_name: None,
1236            }],
1237        )
1238        .await;
1239
1240        assert_eq!(result.num_rows(), 2);
1241        let names = result
1242            .column(0)
1243            .as_any()
1244            .downcast_ref::<StringArray>()
1245            .unwrap();
1246        let vals = result
1247            .column(1)
1248            .as_any()
1249            .downcast_ref::<Float64Array>()
1250            .unwrap();
1251
1252        for i in 0..2 {
1253            match names.value(i) {
1254                // MNOR({0.3, 0.5}) = 0.65
1255                "a" => assert!((vals.value(i) - 0.65).abs() < 1e-10),
1256                // MNOR({0.1, 0.2}) = 1 - 0.9*0.8 = 1 - 0.72 = 0.28
1257                "b" => assert!((vals.value(i) - 0.28).abs() < 1e-10),
1258                _ => panic!("unexpected name"),
1259            }
1260        }
1261    }
1262
1263    // ── MPROD tests ──────────────────────────────────────────────────────
1264
1265    #[tokio::test]
1266    async fn test_prod_single_group() {
1267        // MPROD({0.6, 0.8}) = 0.48
1268        let batch = make_test_batch(vec!["a", "a"], vec![0.6, 0.8]);
1269        let input = make_memory_exec(batch);
1270        let result = execute_fold(
1271            input,
1272            vec![0],
1273            vec![FoldBinding {
1274                output_name: "prob".to_string(),
1275                name: SmolStr::new_static("MPROD"),
1276                aggregate: builtin_agg("MPROD"),
1277                input_col_index: 1,
1278                input_col_name: None,
1279            }],
1280        )
1281        .await;
1282
1283        assert_eq!(result.num_rows(), 1);
1284        let vals = result
1285            .column(1)
1286            .as_any()
1287            .downcast_ref::<Float64Array>()
1288            .unwrap();
1289        assert!((vals.value(0) - 0.48).abs() < 1e-10);
1290    }
1291
1292    #[tokio::test]
1293    async fn test_prod_identity() {
1294        // MPROD({1.0, 1.0}) = 1.0
1295        let batch = make_test_batch(vec!["a", "a"], vec![1.0, 1.0]);
1296        let input = make_memory_exec(batch);
1297        let result = execute_fold(
1298            input,
1299            vec![0],
1300            vec![FoldBinding {
1301                output_name: "prob".to_string(),
1302                name: SmolStr::new_static("MPROD"),
1303                aggregate: builtin_agg("MPROD"),
1304                input_col_index: 1,
1305                input_col_name: None,
1306            }],
1307        )
1308        .await;
1309
1310        let vals = result
1311            .column(1)
1312            .as_any()
1313            .downcast_ref::<Float64Array>()
1314            .unwrap();
1315        assert!((vals.value(0) - 1.0).abs() < 1e-10);
1316    }
1317
1318    #[tokio::test]
1319    async fn test_prod_zero_absorbing() {
1320        // MPROD with 0.0 = 0.0 (zero is absorbing element)
1321        let batch = make_test_batch(vec!["a", "a", "a"], vec![0.5, 0.0, 0.8]);
1322        let input = make_memory_exec(batch);
1323        let result = execute_fold(
1324            input,
1325            vec![0],
1326            vec![FoldBinding {
1327                output_name: "prob".to_string(),
1328                name: SmolStr::new_static("MPROD"),
1329                aggregate: builtin_agg("MPROD"),
1330                input_col_index: 1,
1331                input_col_name: None,
1332            }],
1333        )
1334        .await;
1335
1336        let vals = result
1337            .column(1)
1338            .as_any()
1339            .downcast_ref::<Float64Array>()
1340            .unwrap();
1341        assert!((vals.value(0) - 0.0).abs() < 1e-10);
1342    }
1343
1344    #[tokio::test]
1345    async fn test_prod_underflow_protection() {
1346        // 50 × 0.5 ≈ 8.88e-16, should not be exactly 0 thanks to log-space
1347        let names: Vec<&str> = vec!["a"; 50];
1348        let values: Vec<f64> = vec![0.5; 50];
1349        let batch = make_test_batch(names, values);
1350        let input = make_memory_exec(batch);
1351        let result = execute_fold(
1352            input,
1353            vec![0],
1354            vec![FoldBinding {
1355                output_name: "prob".to_string(),
1356                name: SmolStr::new_static("MPROD"),
1357                aggregate: builtin_agg("MPROD"),
1358                input_col_index: 1,
1359                input_col_name: None,
1360            }],
1361        )
1362        .await;
1363
1364        let vals = result
1365            .column(1)
1366            .as_any()
1367            .downcast_ref::<Float64Array>()
1368            .unwrap();
1369        let expected = 0.5_f64.powi(50); // ≈ 8.88e-16
1370        assert!(vals.value(0) > 0.0, "should not underflow to zero");
1371        assert!(
1372            (vals.value(0) - expected).abs() / expected < 1e-6,
1373            "result {} should be close to expected {}",
1374            vals.value(0),
1375            expected
1376        );
1377    }
1378
1379    // ── MNOR/MPROD mathematical correctness tests ───────────────────────
1380
1381    fn make_nullable_test_batch(names: Vec<&str>, values: Vec<Option<f64>>) -> RecordBatch {
1382        let schema = Arc::new(Schema::new(vec![
1383            Field::new("name", DataType::Utf8, true),
1384            Field::new("value", DataType::Float64, true),
1385        ]));
1386        RecordBatch::try_new(
1387            schema,
1388            vec![
1389                Arc::new(StringArray::from(
1390                    names.into_iter().map(Some).collect::<Vec<_>>(),
1391                )),
1392                Arc::new(Float64Array::from(values)),
1393            ],
1394        )
1395        .unwrap()
1396    }
1397
1398    #[tokio::test]
1399    async fn test_nor_single_element() {
1400        // MNOR({0.7}) = 0.7 (n=1 identity)
1401        let batch = make_test_batch(vec!["a"], vec![0.7]);
1402        let input = make_memory_exec(batch);
1403        let result = execute_fold(
1404            input,
1405            vec![0],
1406            vec![FoldBinding {
1407                output_name: "prob".to_string(),
1408                name: SmolStr::new_static("MNOR"),
1409                aggregate: builtin_agg("MNOR"),
1410                input_col_index: 1,
1411                input_col_name: None,
1412            }],
1413        )
1414        .await;
1415        let vals = result
1416            .column(1)
1417            .as_any()
1418            .downcast_ref::<Float64Array>()
1419            .unwrap();
1420        assert!((vals.value(0) - 0.7).abs() < 1e-10);
1421    }
1422
1423    #[tokio::test]
1424    async fn test_prod_single_element() {
1425        // MPROD({0.7}) = 0.7 (n=1 identity)
1426        let batch = make_test_batch(vec!["a"], vec![0.7]);
1427        let input = make_memory_exec(batch);
1428        let result = execute_fold(
1429            input,
1430            vec![0],
1431            vec![FoldBinding {
1432                output_name: "prob".to_string(),
1433                name: SmolStr::new_static("MPROD"),
1434                aggregate: builtin_agg("MPROD"),
1435                input_col_index: 1,
1436                input_col_name: None,
1437            }],
1438        )
1439        .await;
1440        let vals = result
1441            .column(1)
1442            .as_any()
1443            .downcast_ref::<Float64Array>()
1444            .unwrap();
1445        assert!((vals.value(0) - 0.7).abs() < 1e-10);
1446    }
1447
1448    #[tokio::test]
1449    async fn test_nor_three_elements() {
1450        // MNOR({0.3, 0.4, 0.5}) = 1 - (0.7)(0.6)(0.5) = 0.79
1451        let batch = make_test_batch(vec!["a", "a", "a"], vec![0.3, 0.4, 0.5]);
1452        let input = make_memory_exec(batch);
1453        let result = execute_fold(
1454            input,
1455            vec![0],
1456            vec![FoldBinding {
1457                output_name: "prob".to_string(),
1458                name: SmolStr::new_static("MNOR"),
1459                aggregate: builtin_agg("MNOR"),
1460                input_col_index: 1,
1461                input_col_name: None,
1462            }],
1463        )
1464        .await;
1465        let vals = result
1466            .column(1)
1467            .as_any()
1468            .downcast_ref::<Float64Array>()
1469            .unwrap();
1470        assert!((vals.value(0) - 0.79).abs() < 1e-10);
1471    }
1472
1473    #[tokio::test]
1474    async fn test_nor_four_elements_spec_example() {
1475        // Spec §4.5: MNOR({0.72, 0.54, 0.56, 0.42}) = 1 - (0.28)(0.46)(0.44)(0.58) = 0.96713024
1476        let batch = make_test_batch(vec!["a", "a", "a", "a"], vec![0.72, 0.54, 0.56, 0.42]);
1477        let input = make_memory_exec(batch);
1478        let result = execute_fold(
1479            input,
1480            vec![0],
1481            vec![FoldBinding {
1482                output_name: "prob".to_string(),
1483                name: SmolStr::new_static("MNOR"),
1484                aggregate: builtin_agg("MNOR"),
1485                input_col_index: 1,
1486                input_col_name: None,
1487            }],
1488        )
1489        .await;
1490        let vals = result
1491            .column(1)
1492            .as_any()
1493            .downcast_ref::<Float64Array>()
1494            .unwrap();
1495        assert!(
1496            (vals.value(0) - 0.96713024).abs() < 1e-10,
1497            "expected 0.96713024, got {}",
1498            vals.value(0)
1499        );
1500    }
1501
1502    #[tokio::test]
1503    async fn test_prod_three_elements() {
1504        // MPROD({0.5, 0.5, 0.5}) = 0.125
1505        let batch = make_test_batch(vec!["a", "a", "a"], vec![0.5, 0.5, 0.5]);
1506        let input = make_memory_exec(batch);
1507        let result = execute_fold(
1508            input,
1509            vec![0],
1510            vec![FoldBinding {
1511                output_name: "prob".to_string(),
1512                name: SmolStr::new_static("MPROD"),
1513                aggregate: builtin_agg("MPROD"),
1514                input_col_index: 1,
1515                input_col_name: None,
1516            }],
1517        )
1518        .await;
1519        let vals = result
1520            .column(1)
1521            .as_any()
1522            .downcast_ref::<Float64Array>()
1523            .unwrap();
1524        assert!((vals.value(0) - 0.125).abs() < 1e-10);
1525    }
1526
1527    #[tokio::test]
1528    async fn test_nor_absorbing_element() {
1529        // p=1.0 absorbs: MNOR({0.3, 1.0}) = 1.0
1530        let batch = make_test_batch(vec!["a", "a"], vec![0.3, 1.0]);
1531        let input = make_memory_exec(batch);
1532        let result = execute_fold(
1533            input,
1534            vec![0],
1535            vec![FoldBinding {
1536                output_name: "prob".to_string(),
1537                name: SmolStr::new_static("MNOR"),
1538                aggregate: builtin_agg("MNOR"),
1539                input_col_index: 1,
1540                input_col_name: None,
1541            }],
1542        )
1543        .await;
1544        let vals = result
1545            .column(1)
1546            .as_any()
1547            .downcast_ref::<Float64Array>()
1548            .unwrap();
1549        assert!((vals.value(0) - 1.0).abs() < 1e-10);
1550    }
1551
1552    #[tokio::test]
1553    async fn test_prod_clamping() {
1554        // Out-of-range 2.0 clamped to 1.0: MPROD({2.0, 0.5}) = 1.0 * 0.5 = 0.5
1555        let batch = make_test_batch(vec!["a", "a"], vec![2.0, 0.5]);
1556        let input = make_memory_exec(batch);
1557        let result = execute_fold(
1558            input,
1559            vec![0],
1560            vec![FoldBinding {
1561                output_name: "prob".to_string(),
1562                name: SmolStr::new_static("MPROD"),
1563                aggregate: builtin_agg("MPROD"),
1564                input_col_index: 1,
1565                input_col_name: None,
1566            }],
1567        )
1568        .await;
1569        let vals = result
1570            .column(1)
1571            .as_any()
1572            .downcast_ref::<Float64Array>()
1573            .unwrap();
1574        assert!((vals.value(0) - 0.5).abs() < 1e-10);
1575    }
1576
1577    #[tokio::test]
1578    async fn test_prod_multiple_groups() {
1579        // a: MPROD({0.6, 0.8}) = 0.48, b: MPROD({0.5, 0.5}) = 0.25
1580        let batch = make_test_batch(vec!["a", "a", "b", "b"], vec![0.6, 0.8, 0.5, 0.5]);
1581        let input = make_memory_exec(batch);
1582        let result = execute_fold(
1583            input,
1584            vec![0],
1585            vec![FoldBinding {
1586                output_name: "prob".to_string(),
1587                name: SmolStr::new_static("MPROD"),
1588                aggregate: builtin_agg("MPROD"),
1589                input_col_index: 1,
1590                input_col_name: None,
1591            }],
1592        )
1593        .await;
1594
1595        assert_eq!(result.num_rows(), 2);
1596        let names = result
1597            .column(0)
1598            .as_any()
1599            .downcast_ref::<StringArray>()
1600            .unwrap();
1601        let vals = result
1602            .column(1)
1603            .as_any()
1604            .downcast_ref::<Float64Array>()
1605            .unwrap();
1606        for i in 0..2 {
1607            match names.value(i) {
1608                "a" => assert!((vals.value(i) - 0.48).abs() < 1e-10),
1609                "b" => assert!((vals.value(i) - 0.25).abs() < 1e-10),
1610                _ => panic!("unexpected group name"),
1611            }
1612        }
1613    }
1614
1615    #[tokio::test]
1616    async fn test_nor_commutativity() {
1617        // Order independence: MNOR({0.2, 0.5, 0.8}) = MNOR({0.8, 0.5, 0.2}) = 0.92
1618        let fwd = make_test_batch(vec!["a", "a", "a"], vec![0.2, 0.5, 0.8]);
1619        let rev = make_test_batch(vec!["a", "a", "a"], vec![0.8, 0.5, 0.2]);
1620        let binding = vec![FoldBinding {
1621            output_name: "prob".to_string(),
1622            name: SmolStr::new_static("MNOR"),
1623            aggregate: builtin_agg("MNOR"),
1624            input_col_index: 1,
1625            input_col_name: None,
1626        }];
1627        let r1 = execute_fold(make_memory_exec(fwd), vec![0], binding.clone()).await;
1628        let r2 = execute_fold(make_memory_exec(rev), vec![0], binding).await;
1629        let v1 = r1
1630            .column(1)
1631            .as_any()
1632            .downcast_ref::<Float64Array>()
1633            .unwrap()
1634            .value(0);
1635        let v2 = r2
1636            .column(1)
1637            .as_any()
1638            .downcast_ref::<Float64Array>()
1639            .unwrap()
1640            .value(0);
1641        assert!((v1 - 0.92).abs() < 1e-10);
1642        assert!((v2 - 0.92).abs() < 1e-10);
1643        assert!((v1 - v2).abs() < 1e-15, "commutativity violated");
1644    }
1645
1646    #[tokio::test]
1647    async fn test_prod_commutativity() {
1648        // Order independence: MPROD({0.5, 0.25}) = MPROD({0.25, 0.5}) = 0.125
1649        let fwd = make_test_batch(vec!["a", "a"], vec![0.5, 0.25]);
1650        let rev = make_test_batch(vec!["a", "a"], vec![0.25, 0.5]);
1651        let binding = vec![FoldBinding {
1652            output_name: "prob".to_string(),
1653            name: SmolStr::new_static("MPROD"),
1654            aggregate: builtin_agg("MPROD"),
1655            input_col_index: 1,
1656            input_col_name: None,
1657        }];
1658        let r1 = execute_fold(make_memory_exec(fwd), vec![0], binding.clone()).await;
1659        let r2 = execute_fold(make_memory_exec(rev), vec![0], binding).await;
1660        let v1 = r1
1661            .column(1)
1662            .as_any()
1663            .downcast_ref::<Float64Array>()
1664            .unwrap()
1665            .value(0);
1666        let v2 = r2
1667            .column(1)
1668            .as_any()
1669            .downcast_ref::<Float64Array>()
1670            .unwrap()
1671            .value(0);
1672        assert!((v1 - 0.125).abs() < 1e-10);
1673        assert!((v2 - 0.125).abs() < 1e-10);
1674        assert!((v1 - v2).abs() < 1e-15, "commutativity violated");
1675    }
1676
1677    #[tokio::test]
1678    async fn test_nor_boundary_near_zero() {
1679        // Precision near 0: MNOR({0.001, 0.002}) = 1 - (0.999)(0.998) = 0.002998
1680        let batch = make_test_batch(vec!["a", "a"], vec![0.001, 0.002]);
1681        let input = make_memory_exec(batch);
1682        let result = execute_fold(
1683            input,
1684            vec![0],
1685            vec![FoldBinding {
1686                output_name: "prob".to_string(),
1687                name: SmolStr::new_static("MNOR"),
1688                aggregate: builtin_agg("MNOR"),
1689                input_col_index: 1,
1690                input_col_name: None,
1691            }],
1692        )
1693        .await;
1694        let vals = result
1695            .column(1)
1696            .as_any()
1697            .downcast_ref::<Float64Array>()
1698            .unwrap();
1699        let expected = 1.0 - 0.999 * 0.998;
1700        assert!(
1701            (vals.value(0) - expected).abs() < 1e-10,
1702            "expected {}, got {}",
1703            expected,
1704            vals.value(0)
1705        );
1706    }
1707
1708    #[tokio::test]
1709    async fn test_nor_boundary_near_one() {
1710        // Precision near 1: MNOR({0.999, 0.998}) = 1 - (0.001)(0.002) = 0.999998
1711        let batch = make_test_batch(vec!["a", "a"], vec![0.999, 0.998]);
1712        let input = make_memory_exec(batch);
1713        let result = execute_fold(
1714            input,
1715            vec![0],
1716            vec![FoldBinding {
1717                output_name: "prob".to_string(),
1718                name: SmolStr::new_static("MNOR"),
1719                aggregate: builtin_agg("MNOR"),
1720                input_col_index: 1,
1721                input_col_name: None,
1722            }],
1723        )
1724        .await;
1725        let vals = result
1726            .column(1)
1727            .as_any()
1728            .downcast_ref::<Float64Array>()
1729            .unwrap();
1730        let expected = 1.0 - 0.001 * 0.002;
1731        assert!(
1732            (vals.value(0) - expected).abs() < 1e-10,
1733            "expected {}, got {}",
1734            expected,
1735            vals.value(0)
1736        );
1737    }
1738
1739    #[tokio::test]
1740    async fn test_prod_boundary_near_zero() {
1741        // Precision near 0: MPROD({0.001, 0.002}) = 2e-6
1742        let batch = make_test_batch(vec!["a", "a"], vec![0.001, 0.002]);
1743        let input = make_memory_exec(batch);
1744        let result = execute_fold(
1745            input,
1746            vec![0],
1747            vec![FoldBinding {
1748                output_name: "prob".to_string(),
1749                name: SmolStr::new_static("MPROD"),
1750                aggregate: builtin_agg("MPROD"),
1751                input_col_index: 1,
1752                input_col_name: None,
1753            }],
1754        )
1755        .await;
1756        let vals = result
1757            .column(1)
1758            .as_any()
1759            .downcast_ref::<Float64Array>()
1760            .unwrap();
1761        assert!(
1762            (vals.value(0) - 2e-6).abs() < 1e-15,
1763            "expected 2e-6, got {}",
1764            vals.value(0)
1765        );
1766    }
1767
1768    #[tokio::test]
1769    async fn test_nor_empty_input() {
1770        // Empty input → 0 rows output
1771        let schema = Arc::new(Schema::new(vec![
1772            Field::new("name", DataType::Utf8, true),
1773            Field::new("value", DataType::Float64, true),
1774        ]));
1775        let batch = RecordBatch::new_empty(schema);
1776        let input = make_memory_exec(batch);
1777        let result = execute_fold(
1778            input,
1779            vec![0],
1780            vec![FoldBinding {
1781                output_name: "prob".to_string(),
1782                name: SmolStr::new_static("MNOR"),
1783                aggregate: builtin_agg("MNOR"),
1784                input_col_index: 1,
1785                input_col_name: None,
1786            }],
1787        )
1788        .await;
1789        assert_eq!(result.num_rows(), 0);
1790    }
1791
1792    #[tokio::test]
1793    async fn test_nor_nan_handling() {
1794        // NaN propagates through noisy-OR
1795        let batch = make_test_batch(vec!["a", "a"], vec![0.3, f64::NAN]);
1796        let input = make_memory_exec(batch);
1797        let result = execute_fold(
1798            input,
1799            vec![0],
1800            vec![FoldBinding {
1801                output_name: "prob".to_string(),
1802                name: SmolStr::new_static("MNOR"),
1803                aggregate: builtin_agg("MNOR"),
1804                input_col_index: 1,
1805                input_col_name: None,
1806            }],
1807        )
1808        .await;
1809        let vals = result
1810            .column(1)
1811            .as_any()
1812            .downcast_ref::<Float64Array>()
1813            .unwrap();
1814        assert!(vals.value(0).is_nan(), "NaN should propagate through MNOR");
1815    }
1816
1817    #[tokio::test]
1818    async fn test_prod_nan_handling() {
1819        // NaN propagates through product
1820        let batch = make_test_batch(vec!["a", "a"], vec![0.5, f64::NAN]);
1821        let input = make_memory_exec(batch);
1822        let result = execute_fold(
1823            input,
1824            vec![0],
1825            vec![FoldBinding {
1826                output_name: "prob".to_string(),
1827                name: SmolStr::new_static("MPROD"),
1828                aggregate: builtin_agg("MPROD"),
1829                input_col_index: 1,
1830                input_col_name: None,
1831            }],
1832        )
1833        .await;
1834        let vals = result
1835            .column(1)
1836            .as_any()
1837            .downcast_ref::<Float64Array>()
1838            .unwrap();
1839        assert!(vals.value(0).is_nan(), "NaN should propagate through MPROD");
1840    }
1841
1842    #[tokio::test]
1843    async fn test_prod_infinity_handling() {
1844        // +∞ clamped to 1.0: MPROD({0.5, ∞}) = 0.5 * 1.0 = 0.5
1845        let batch = make_test_batch(vec!["a", "a"], vec![0.5, f64::INFINITY]);
1846        let input = make_memory_exec(batch);
1847        let result = execute_fold(
1848            input,
1849            vec![0],
1850            vec![FoldBinding {
1851                output_name: "prob".to_string(),
1852                name: SmolStr::new_static("MPROD"),
1853                aggregate: builtin_agg("MPROD"),
1854                input_col_index: 1,
1855                input_col_name: None,
1856            }],
1857        )
1858        .await;
1859        let vals = result
1860            .column(1)
1861            .as_any()
1862            .downcast_ref::<Float64Array>()
1863            .unwrap();
1864        assert!((vals.value(0) - 0.5).abs() < 1e-10);
1865    }
1866
1867    #[tokio::test]
1868    async fn test_nor_infinity_handling() {
1869        // +∞ clamped to 1.0, which absorbs: MNOR({0.3, ∞}) = 1.0
1870        let batch = make_test_batch(vec!["a", "a"], vec![0.3, f64::INFINITY]);
1871        let input = make_memory_exec(batch);
1872        let result = execute_fold(
1873            input,
1874            vec![0],
1875            vec![FoldBinding {
1876                output_name: "prob".to_string(),
1877                name: SmolStr::new_static("MNOR"),
1878                aggregate: builtin_agg("MNOR"),
1879                input_col_index: 1,
1880                input_col_name: None,
1881            }],
1882        )
1883        .await;
1884        let vals = result
1885            .column(1)
1886            .as_any()
1887            .downcast_ref::<Float64Array>()
1888            .unwrap();
1889        assert!((vals.value(0) - 1.0).abs() < 1e-10);
1890    }
1891
1892    #[tokio::test]
1893    async fn test_nor_all_null_values() {
1894        // All-null input → null output
1895        let batch = make_nullable_test_batch(vec!["a", "a"], vec![None, None]);
1896        let input = make_memory_exec(batch);
1897        let result = execute_fold(
1898            input,
1899            vec![0],
1900            vec![FoldBinding {
1901                output_name: "prob".to_string(),
1902                name: SmolStr::new_static("MNOR"),
1903                aggregate: builtin_agg("MNOR"),
1904                input_col_index: 1,
1905                input_col_name: None,
1906            }],
1907        )
1908        .await;
1909        assert_eq!(result.num_rows(), 1);
1910        let vals = result
1911            .column(1)
1912            .as_any()
1913            .downcast_ref::<Float64Array>()
1914            .unwrap();
1915        assert!(vals.is_null(0), "all-null MNOR should produce null");
1916    }
1917
1918    #[tokio::test]
1919    async fn test_prod_all_null_values() {
1920        // All-null input → null output
1921        let batch = make_nullable_test_batch(vec!["a", "a"], vec![None, None]);
1922        let input = make_memory_exec(batch);
1923        let result = execute_fold(
1924            input,
1925            vec![0],
1926            vec![FoldBinding {
1927                output_name: "prob".to_string(),
1928                name: SmolStr::new_static("MPROD"),
1929                aggregate: builtin_agg("MPROD"),
1930                input_col_index: 1,
1931                input_col_name: None,
1932            }],
1933        )
1934        .await;
1935        assert_eq!(result.num_rows(), 1);
1936        let vals = result
1937            .column(1)
1938            .as_any()
1939            .downcast_ref::<Float64Array>()
1940            .unwrap();
1941        assert!(vals.is_null(0), "all-null MPROD should produce null");
1942    }
1943
1944    #[tokio::test]
1945    async fn test_nor_mixed_null_values() {
1946        // Nulls skipped: MNOR({0.3, null, 0.5}) = 1 - (0.7)(0.5) = 0.65
1947        let batch = make_nullable_test_batch(vec!["a", "a", "a"], vec![Some(0.3), None, Some(0.5)]);
1948        let input = make_memory_exec(batch);
1949        let result = execute_fold(
1950            input,
1951            vec![0],
1952            vec![FoldBinding {
1953                output_name: "prob".to_string(),
1954                name: SmolStr::new_static("MNOR"),
1955                aggregate: builtin_agg("MNOR"),
1956                input_col_index: 1,
1957                input_col_name: None,
1958            }],
1959        )
1960        .await;
1961        let vals = result
1962            .column(1)
1963            .as_any()
1964            .downcast_ref::<Float64Array>()
1965            .unwrap();
1966        assert!((vals.value(0) - 0.65).abs() < 1e-10);
1967    }
1968
1969    #[tokio::test]
1970    async fn test_prod_mixed_null_values() {
1971        // Nulls skipped: MPROD({0.6, null, 0.8}) = 0.6 * 0.8 = 0.48
1972        let batch = make_nullable_test_batch(vec!["a", "a", "a"], vec![Some(0.6), None, Some(0.8)]);
1973        let input = make_memory_exec(batch);
1974        let result = execute_fold(
1975            input,
1976            vec![0],
1977            vec![FoldBinding {
1978                output_name: "prob".to_string(),
1979                name: SmolStr::new_static("MPROD"),
1980                aggregate: builtin_agg("MPROD"),
1981                input_col_index: 1,
1982                input_col_name: None,
1983            }],
1984        )
1985        .await;
1986        let vals = result
1987            .column(1)
1988            .as_any()
1989            .downcast_ref::<Float64Array>()
1990            .unwrap();
1991        assert!((vals.value(0) - 0.48).abs() < 1e-10);
1992    }
1993
1994    #[tokio::test]
1995    async fn test_nor_many_small_values() {
1996        // Large accumulation: 20 × 0.1 → 1 - 0.9^20 ≈ 0.8784
1997        let names: Vec<&str> = vec!["a"; 20];
1998        let values: Vec<f64> = vec![0.1; 20];
1999        let batch = make_test_batch(names, values);
2000        let input = make_memory_exec(batch);
2001        let result = execute_fold(
2002            input,
2003            vec![0],
2004            vec![FoldBinding {
2005                output_name: "prob".to_string(),
2006                name: SmolStr::new_static("MNOR"),
2007                aggregate: builtin_agg("MNOR"),
2008                input_col_index: 1,
2009                input_col_name: None,
2010            }],
2011        )
2012        .await;
2013        let vals = result
2014            .column(1)
2015            .as_any()
2016            .downcast_ref::<Float64Array>()
2017            .unwrap();
2018        let expected = 1.0 - 0.9_f64.powi(20);
2019        assert!(
2020            (vals.value(0) - expected).abs() < 1e-10,
2021            "expected {}, got {}",
2022            expected,
2023            vals.value(0)
2024        );
2025    }
2026
2027    // ── Aggregate-trait classification tests ──────────────────────────────
2028
2029    #[test]
2030    fn trait_dispatch_monotonicity() {
2031        for name in [
2032            "SUM", "MAX", "MIN", "COUNT", "AVG", "COLLECT", "MNOR", "MPROD",
2033        ] {
2034            let agg = builtin_agg(name);
2035            let sl = agg.semilattice();
2036            // MIN/MAX/MNOR/MPROD/COLLECT/COUNT are monotone; SUM/AVG are not.
2037            let expect_monotone =
2038                matches!(name, "MIN" | "MAX" | "MNOR" | "MPROD" | "COLLECT" | "COUNT");
2039            assert_eq!(
2040                sl.monotone_join, expect_monotone,
2041                "monotone_join mismatch for {name}"
2042            );
2043        }
2044    }
2045
2046    #[test]
2047    fn trait_dispatch_initial_accumulator() {
2048        // The row-level fast path uses `initial_accum_f64()`.
2049        assert_eq!(builtin_agg("SUM").initial_accum_f64(), Some(0.0));
2050        assert_eq!(builtin_agg("COUNT").initial_accum_f64(), Some(0.0));
2051        assert_eq!(builtin_agg("MNOR").initial_accum_f64(), Some(0.0));
2052        assert_eq!(
2053            builtin_agg("MAX").initial_accum_f64(),
2054            Some(f64::NEG_INFINITY)
2055        );
2056        assert_eq!(builtin_agg("MIN").initial_accum_f64(), Some(f64::INFINITY));
2057        assert_eq!(builtin_agg("MPROD").initial_accum_f64(), Some(1.0));
2058        // AVG and COLLECT have no row-level fast path — return None.
2059        assert_eq!(builtin_agg("AVG").initial_accum_f64(), None);
2060        assert_eq!(builtin_agg("COLLECT").initial_accum_f64(), None);
2061    }
2062
2063    #[test]
2064    fn trait_dispatch_probability_predicate() {
2065        // is_probability_aggregate is the trait predicate for probability-domain aggregates.
2066        for name in ["MNOR", "MPROD"] {
2067            assert!(
2068                builtin_agg(name).is_probability_aggregate(),
2069                "expected {name} to be probability-domain"
2070            );
2071        }
2072        for name in ["SUM", "MAX", "MIN", "COUNT", "AVG", "COLLECT"] {
2073            assert!(
2074                !builtin_agg(name).is_probability_aggregate(),
2075                "{name} should NOT be probability-domain"
2076            );
2077        }
2078    }
2079
2080    #[test]
2081    fn trait_dispatch_noisy_or_predicate() {
2082        // is_noisy_or distinguishes MNOR from MPROD for semiring-op selection.
2083        assert!(builtin_agg("MNOR").is_noisy_or());
2084        assert!(!builtin_agg("MPROD").is_noisy_or());
2085    }
2086
2087    // ── Strict mode tests (Phase 5) ──────────────────────────────────────
2088
2089    async fn execute_fold_strict(
2090        input: Arc<dyn ExecutionPlan>,
2091        key_indices: Vec<usize>,
2092        fold_bindings: Vec<FoldBinding>,
2093        strict: bool,
2094    ) -> DFResult<RecordBatch> {
2095        let exec = FoldExec::new(input, key_indices, fold_bindings, strict, 1e-15);
2096        let ctx = SessionContext::new();
2097        let task_ctx = ctx.task_ctx();
2098        let stream = exec.execute(0, task_ctx).unwrap();
2099        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream).await?;
2100        if batches.is_empty() {
2101            Ok(RecordBatch::new_empty(exec.schema()))
2102        } else {
2103            arrow::compute::concat_batches(&exec.schema(), &batches).map_err(arrow_err)
2104        }
2105    }
2106
2107    #[tokio::test]
2108    async fn test_nor_strict_rejects_above_one() {
2109        let batch = make_test_batch(vec!["a"], vec![1.5]);
2110        let input = make_memory_exec(batch);
2111        let result = execute_fold_strict(
2112            input,
2113            vec![0],
2114            vec![FoldBinding {
2115                output_name: "p".into(),
2116                name: SmolStr::new_static("MNOR"),
2117                aggregate: builtin_agg("MNOR"),
2118                input_col_index: 1,
2119                input_col_name: None,
2120            }],
2121            true,
2122        )
2123        .await;
2124        assert!(result.is_err());
2125        let err = result.unwrap_err().to_string();
2126        assert!(
2127            err.contains("strict_probability_domain"),
2128            "Expected strict error, got: {}",
2129            err
2130        );
2131    }
2132
2133    #[tokio::test]
2134    async fn test_nor_strict_rejects_negative() {
2135        let batch = make_test_batch(vec!["a"], vec![-0.1]);
2136        let input = make_memory_exec(batch);
2137        let result = execute_fold_strict(
2138            input,
2139            vec![0],
2140            vec![FoldBinding {
2141                output_name: "p".into(),
2142                name: SmolStr::new_static("MNOR"),
2143                aggregate: builtin_agg("MNOR"),
2144                input_col_index: 1,
2145                input_col_name: None,
2146            }],
2147            true,
2148        )
2149        .await;
2150        assert!(result.is_err());
2151        let err = result.unwrap_err().to_string();
2152        assert!(
2153            err.contains("strict_probability_domain"),
2154            "Expected strict error, got: {}",
2155            err
2156        );
2157    }
2158
2159    #[tokio::test]
2160    async fn test_prod_strict_rejects_above_one() {
2161        let batch = make_test_batch(vec!["a"], vec![2.0]);
2162        let input = make_memory_exec(batch);
2163        let result = execute_fold_strict(
2164            input,
2165            vec![0],
2166            vec![FoldBinding {
2167                output_name: "p".into(),
2168                name: SmolStr::new_static("MPROD"),
2169                aggregate: builtin_agg("MPROD"),
2170                input_col_index: 1,
2171                input_col_name: None,
2172            }],
2173            true,
2174        )
2175        .await;
2176        assert!(result.is_err());
2177        let err = result.unwrap_err().to_string();
2178        assert!(
2179            err.contains("strict_probability_domain"),
2180            "Expected strict error, got: {}",
2181            err
2182        );
2183    }
2184
2185    #[tokio::test]
2186    async fn test_prod_strict_rejects_negative() {
2187        let batch = make_test_batch(vec!["a"], vec![-0.5]);
2188        let input = make_memory_exec(batch);
2189        let result = execute_fold_strict(
2190            input,
2191            vec![0],
2192            vec![FoldBinding {
2193                output_name: "p".into(),
2194                name: SmolStr::new_static("MPROD"),
2195                aggregate: builtin_agg("MPROD"),
2196                input_col_index: 1,
2197                input_col_name: None,
2198            }],
2199            true,
2200        )
2201        .await;
2202        assert!(result.is_err());
2203        let err = result.unwrap_err().to_string();
2204        assert!(
2205            err.contains("strict_probability_domain"),
2206            "Expected strict error, got: {}",
2207            err
2208        );
2209    }
2210
2211    #[tokio::test]
2212    async fn test_nor_strict_accepts_valid() {
2213        let batch = make_test_batch(vec!["a", "a"], vec![0.3, 0.5]);
2214        let input = make_memory_exec(batch);
2215        let result = execute_fold_strict(
2216            input,
2217            vec![0],
2218            vec![FoldBinding {
2219                output_name: "p".into(),
2220                name: SmolStr::new_static("MNOR"),
2221                aggregate: builtin_agg("MNOR"),
2222                input_col_index: 1,
2223                input_col_name: None,
2224            }],
2225            true,
2226        )
2227        .await;
2228        assert!(result.is_ok(), "Expected Ok, got: {:?}", result.err());
2229        let batch = result.unwrap();
2230        let vals = batch
2231            .column(1)
2232            .as_any()
2233            .downcast_ref::<Float64Array>()
2234            .unwrap();
2235        let expected = 0.65; // 1 - (1-0.3)(1-0.5)
2236        assert!(
2237            (vals.value(0) - expected).abs() < 1e-10,
2238            "expected {}, got {}",
2239            expected,
2240            vals.value(0)
2241        );
2242    }
2243
2244    #[tokio::test]
2245    async fn test_count_all_groups_by_key() {
2246        // Two groups: "a" (2 rows), "b" (1 row)
2247        let batch = make_test_batch(vec!["a", "a", "b"], vec![10.0, 20.0, 30.0]);
2248        let input = make_memory_exec(batch);
2249        let result = execute_fold(
2250            input,
2251            vec![0],
2252            vec![FoldBinding {
2253                output_name: "cnt".to_string(),
2254                name: SmolStr::new_static("COUNTALL"),
2255                aggregate: builtin_agg("COUNTALL"),
2256                input_col_index: 0, // unused for CountAll
2257                input_col_name: None,
2258            }],
2259        )
2260        .await;
2261
2262        assert_eq!(result.num_rows(), 2, "Should have 2 groups");
2263        let counts = result
2264            .column(1)
2265            .as_any()
2266            .downcast_ref::<Int64Array>()
2267            .unwrap();
2268        assert_eq!(counts.value(0), 2, "Group 'a' should have count 2");
2269        assert_eq!(counts.value(1), 1, "Group 'b' should have count 1");
2270    }
2271
2272    // ── Registry-resolve sanity tests ────────────────────────────────────
2273
2274    /// A test-only `LocyAggregate` that's not in `uni-plugin-builtin`.
2275    /// Used to prove `resolve_locy_aggregate` walks the registry rather
2276    /// than dispatching from a hardcoded built-in table.
2277    #[derive(Debug)]
2278    struct IdentityAgg;
2279
2280    impl LocyAggregate for IdentityAgg {
2281        fn semilattice(&self) -> uni_plugin::traits::locy::Semilattice {
2282            uni_plugin::traits::locy::Semilattice::BOUNDED_MIN_MAX
2283        }
2284        fn output_type(&self) -> arrow_schema::DataType {
2285            arrow_schema::DataType::Float64
2286        }
2287        fn create(&self) -> Box<dyn uni_plugin::traits::locy::LocyAggState> {
2288            panic!("IdentityAgg::create not used in this sanity test")
2289        }
2290    }
2291
2292    #[test]
2293    fn resolve_locy_aggregate_returns_registered_instance() {
2294        let registry = uni_plugin::PluginRegistry::new();
2295        let plugin_id = uni_plugin::PluginId::new(uni_plugin::QName::BUILTIN_NS);
2296        let caps = uni_plugin::CapabilitySet::from_iter_of([uni_plugin::Capability::LocyAggregate]);
2297
2298        let registered: Arc<dyn LocyAggregate> = Arc::new(IdentityAgg);
2299        let mut r = uni_plugin::PluginRegistrar::new(plugin_id, &caps, &registry);
2300        r.locy_aggregate(
2301            uni_plugin::QName::builtin("TEST_IDENTITY"),
2302            Arc::clone(&registered),
2303        )
2304        .expect("register");
2305        r.commit_to_registry().expect("commit");
2306
2307        let resolved = resolve_locy_aggregate(&registry, "TEST_IDENTITY")
2308            .expect("registered aggregate should resolve");
2309        assert!(
2310            Arc::ptr_eq(&registered, &resolved.aggregate),
2311            "registry must return the exact Arc that was registered"
2312        );
2313
2314        // Unknown name still returns None — the resolver does not fall back.
2315        assert!(resolve_locy_aggregate(&registry, "NOT_REGISTERED").is_none());
2316    }
2317
2318    #[test]
2319    fn default_locy_plugin_registry_contains_all_builtins() {
2320        let r = default_locy_plugin_registry();
2321        for name in [
2322            "MIN", "MAX", "SUM", "MSUM", "COUNT", "AVG", "COLLECT", "MNOR", "MPROD",
2323        ] {
2324            assert!(
2325                resolve_locy_aggregate(&r, name).is_some(),
2326                "default registry should contain built-in `{name}`"
2327            );
2328        }
2329    }
2330
2331    // ── User-defined aggregate runs through the trait (G1 regression) ─────
2332
2333    /// A novel aggregate not in `uni-plugin-builtin`: per-group `max − min`.
2334    ///
2335    /// Requires real columnar state (two accumulators), so it is *not*
2336    /// expressible via the `update_step` scalar fast path. Before the fold
2337    /// executor dispatched through [`LocyAggState`], a binding like this hit
2338    /// the closed-enum `_ => Err("unsupported aggregate")` arm at runtime;
2339    /// now it executes through `create`/`ingest_indices`/`finalize`.
2340    #[derive(Debug)]
2341    struct RangeAgg;
2342
2343    impl LocyAggregate for RangeAgg {
2344        fn semilattice(&self) -> uni_plugin::traits::locy::Semilattice {
2345            uni_plugin::traits::locy::Semilattice::NON_MONOTONE
2346        }
2347        fn output_type(&self) -> DataType {
2348            DataType::Float64
2349        }
2350        fn create(&self) -> Box<dyn uni_plugin::traits::locy::LocyAggState> {
2351            Box::new(RangeState {
2352                min: None,
2353                max: None,
2354            })
2355        }
2356    }
2357
2358    #[derive(Debug)]
2359    struct RangeState {
2360        min: Option<f64>,
2361        max: Option<f64>,
2362    }
2363
2364    impl uni_plugin::traits::locy::LocyAggState for RangeState {
2365        fn as_any(&self) -> &dyn std::any::Any {
2366            self
2367        }
2368        fn ingest_indices(
2369            &mut self,
2370            col: &dyn Array,
2371            indices: &[usize],
2372            _cx: &FoldContext,
2373        ) -> Result<(), uni_plugin::FnError> {
2374            let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();
2375            for &i in indices {
2376                if arr.is_null(i) {
2377                    continue;
2378                }
2379                let v = arr.value(i);
2380                self.min = Some(self.min.map_or(v, |m| m.min(v)));
2381                self.max = Some(self.max.map_or(v, |m| m.max(v)));
2382            }
2383            Ok(())
2384        }
2385        fn merge(
2386            &mut self,
2387            _other: &dyn uni_plugin::traits::locy::LocyAggState,
2388        ) -> Result<(), uni_plugin::FnError> {
2389            Ok(())
2390        }
2391        fn finalize(&self) -> Result<ScalarValue, uni_plugin::FnError> {
2392            match (self.min, self.max) {
2393                (Some(lo), Some(hi)) => Ok(ScalarValue::Float64(Some(hi - lo))),
2394                _ => Ok(ScalarValue::Float64(None)),
2395            }
2396        }
2397    }
2398
2399    #[tokio::test]
2400    async fn user_defined_aggregate_runs_in_non_recursive_fold() {
2401        // group "a": [1.0, 5.0] → range 4.0 ; group "b": [3.0] → range 0.0
2402        let batch = make_test_batch(vec!["a", "a", "b"], vec![1.0, 5.0, 3.0]);
2403        let input = make_memory_exec(batch);
2404        let binding = FoldBinding {
2405            output_name: "r".into(),
2406            name: SmolStr::new_static("RANGE"),
2407            aggregate: Arc::new(RangeAgg),
2408            input_col_index: 1,
2409            input_col_name: Some("value".to_string()),
2410        };
2411        let out = execute_fold(input, vec![0], vec![binding]).await;
2412        assert_eq!(out.num_rows(), 2);
2413        let col = out
2414            .column(1)
2415            .as_any()
2416            .downcast_ref::<Float64Array>()
2417            .expect("range output is Float64");
2418        // ordered_keys preserves first-seen order: "a" then "b".
2419        assert_eq!(col.value(0), 4.0);
2420        assert_eq!(col.value(1), 0.0);
2421    }
2422}