Skip to main content

uni_query/query/df_graph/
locy_fold.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright 2024-2026 Dragonscale Team

//! FOLD operator for Locy.
//!
//! `FoldExec` applies fold (lattice-join) semantics: for each group of rows sharing
//! the same KEY columns, it reduces non-key columns via their declared fold functions.
8
9use crate::query::df_graph::common::{
10    ScalarKey, arrow_err, compute_plan_properties, extract_scalar_key,
11};
12use arrow_array::builder::{Float64Builder, Int64Builder, LargeBinaryBuilder};
13use arrow_array::{Array, Float64Array, Int64Array, RecordBatch};
14use arrow_schema::{DataType, Field, Schema, SchemaRef};
15use datafusion::common::Result as DFResult;
16use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
17use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
18use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
19use futures::{Stream, TryStreamExt};
20use std::any::Any;
21use std::collections::HashMap;
22use std::fmt;
23use std::pin::Pin;
24use std::sync::Arc;
25use std::task::{Context, Poll};
26
/// Which way a fold aggregate is allowed to move between fixpoint iterations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MonotonicDirection {
    /// The aggregate never gets smaller from one iteration to the next.
    NonDecreasing,
    /// The aggregate never gets larger from one iteration to the next.
    NonIncreasing,
}

/// The aggregate function a FOLD binding applies to its input column.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FoldAggKind {
    /// Sum of the input values.
    Sum,
    /// Largest input value.
    Max,
    /// Smallest input value.
    Min,
    /// Number of non-null input values.
    Count,
    /// Number of rows in the group, null or not (like SQL `COUNT(*)`).
    CountAll,
    /// Arithmetic mean of the non-null input values.
    Avg,
    /// Gathers the non-null input values into a list.
    Collect,
    /// Noisy-OR: 1 − ∏(1 − pᵢ).
    Nor,
    /// Product: ∏ pᵢ.
    Prod,
}

impl FoldAggKind {
    /// Returns `true` if this aggregate is monotonic (safe for fixpoint iteration).
    ///
    /// An aggregate is monotonic exactly when it has a
    /// [`monotonicity_direction`](Self::monotonicity_direction).
    pub fn is_monotonic(&self) -> bool {
        self.monotonicity_direction().is_some()
    }

    /// Returns the direction in which this aggregate moves, or `None` for the
    /// non-monotonic aggregates (`Avg`, `Collect`).
    pub fn monotonicity_direction(&self) -> Option<MonotonicDirection> {
        use MonotonicDirection::{NonDecreasing, NonIncreasing};

        match self {
            Self::Avg | Self::Collect => None,
            Self::Min | Self::Prod => Some(NonIncreasing),
            Self::Sum | Self::Max | Self::Count | Self::CountAll | Self::Nor => {
                Some(NonDecreasing)
            }
        }
    }

    /// Returns the identity element of this aggregate, or `None` for the
    /// non-monotonic aggregates (`Avg`, `Collect`).
    pub fn identity(&self) -> Option<f64> {
        let neutral = match self {
            Self::Avg | Self::Collect => return None,
            Self::Prod => 1.0,
            Self::Min => f64::INFINITY,
            Self::Max => f64::NEG_INFINITY,
            Self::Sum | Self::Count | Self::CountAll | Self::Nor => 0.0,
        };
        Some(neutral)
    }
}
88
89/// A single FOLD binding: aggregate an input column into an output column.
90#[derive(Debug, Clone)]
91pub struct FoldBinding {
92    pub output_name: String,
93    pub kind: FoldAggKind,
94    pub input_col_index: usize,
95}
96
97/// DataFusion `ExecutionPlan` that applies FOLD semantics.
98///
99/// Groups rows by KEY columns and computes aggregates (SUM, MAX, MIN, COUNT, AVG, COLLECT)
100/// for each fold binding. Output schema is KEY columns + fold output columns.
101#[derive(Debug)]
102pub struct FoldExec {
103    input: Arc<dyn ExecutionPlan>,
104    key_indices: Vec<usize>,
105    fold_bindings: Vec<FoldBinding>,
106    strict_probability_domain: bool,
107    probability_epsilon: f64,
108    schema: SchemaRef,
109    properties: PlanProperties,
110    metrics: ExecutionPlanMetricsSet,
111}
112
113impl FoldExec {
114    /// Create a new `FoldExec`.
115    ///
116    /// # Arguments
117    /// * `input` - Child execution plan
118    /// * `key_indices` - Indices of KEY columns for grouping
119    /// * `fold_bindings` - Aggregate bindings (output name, kind, input col index)
120    pub fn new(
121        input: Arc<dyn ExecutionPlan>,
122        key_indices: Vec<usize>,
123        fold_bindings: Vec<FoldBinding>,
124        strict_probability_domain: bool,
125        probability_epsilon: f64,
126    ) -> Self {
127        let input_schema = input.schema();
128        let schema = Self::build_output_schema(&input_schema, &key_indices, &fold_bindings);
129        let properties = compute_plan_properties(Arc::clone(&schema));
130
131        Self {
132            input,
133            key_indices,
134            fold_bindings,
135            strict_probability_domain,
136            probability_epsilon,
137            schema,
138            properties,
139            metrics: ExecutionPlanMetricsSet::new(),
140        }
141    }
142
143    fn build_output_schema(
144        input_schema: &SchemaRef,
145        key_indices: &[usize],
146        fold_bindings: &[FoldBinding],
147    ) -> SchemaRef {
148        let mut fields = Vec::new();
149
150        // Key columns preserve original types
151        for &ki in key_indices {
152            fields.push(Arc::new(input_schema.field(ki).clone()));
153        }
154
155        // Fold output columns
156        for binding in fold_bindings {
157            let output_type = match binding.kind {
158                FoldAggKind::Sum | FoldAggKind::Avg | FoldAggKind::Nor | FoldAggKind::Prod => {
159                    DataType::Float64
160                }
161                FoldAggKind::Count | FoldAggKind::CountAll => DataType::Int64,
162                FoldAggKind::Max | FoldAggKind::Min => input_schema
163                    .field(binding.input_col_index)
164                    .data_type()
165                    .clone(),
166                FoldAggKind::Collect => DataType::LargeBinary,
167            };
168            fields.push(Arc::new(Field::new(
169                &binding.output_name,
170                output_type,
171                true,
172            )));
173        }
174
175        Arc::new(Schema::new(fields))
176    }
177}
178
179impl DisplayAs for FoldExec {
180    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181        write!(
182            f,
183            "FoldExec: key_indices={:?}, bindings={:?}",
184            self.key_indices, self.fold_bindings
185        )
186    }
187}
188
189impl ExecutionPlan for FoldExec {
190    fn name(&self) -> &str {
191        "FoldExec"
192    }
193
194    fn as_any(&self) -> &dyn Any {
195        self
196    }
197
198    fn schema(&self) -> SchemaRef {
199        Arc::clone(&self.schema)
200    }
201
202    fn properties(&self) -> &PlanProperties {
203        &self.properties
204    }
205
206    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
207        vec![&self.input]
208    }
209
210    fn with_new_children(
211        self: Arc<Self>,
212        children: Vec<Arc<dyn ExecutionPlan>>,
213    ) -> DFResult<Arc<dyn ExecutionPlan>> {
214        if children.len() != 1 {
215            return Err(datafusion::error::DataFusionError::Plan(
216                "FoldExec requires exactly one child".to_string(),
217            ));
218        }
219        Ok(Arc::new(Self::new(
220            Arc::clone(&children[0]),
221            self.key_indices.clone(),
222            self.fold_bindings.clone(),
223            self.strict_probability_domain,
224            self.probability_epsilon,
225        )))
226    }
227
228    fn execute(
229        &self,
230        partition: usize,
231        context: Arc<TaskContext>,
232    ) -> DFResult<SendableRecordBatchStream> {
233        let input_stream = self.input.execute(partition, Arc::clone(&context))?;
234        let metrics = BaselineMetrics::new(&self.metrics, partition);
235        let key_indices = self.key_indices.clone();
236        let fold_bindings = self.fold_bindings.clone();
237        let strict = self.strict_probability_domain;
238        let epsilon = self.probability_epsilon;
239        let output_schema = Arc::clone(&self.schema);
240        let input_schema = self.input.schema();
241
242        let fut = async move {
243            let batches: Vec<RecordBatch> = input_stream.try_collect().await?;
244
245            if batches.is_empty() {
246                return Ok(RecordBatch::new_empty(output_schema));
247            }
248
249            let batch =
250                arrow::compute::concat_batches(&input_schema, &batches).map_err(arrow_err)?;
251
252            if batch.num_rows() == 0 {
253                return Ok(RecordBatch::new_empty(output_schema));
254            }
255
256            // Group by key columns → row indices, preserving insertion order
257            let mut groups: HashMap<Vec<ScalarKey>, Vec<usize>> = HashMap::new();
258            let mut ordered_keys: Vec<Vec<ScalarKey>> = Vec::new();
259            for row_idx in 0..batch.num_rows() {
260                let key = extract_scalar_key(&batch, &key_indices, row_idx);
261                let entry = groups.entry(key.clone());
262                if matches!(entry, std::collections::hash_map::Entry::Vacant(_)) {
263                    ordered_keys.push(key);
264                }
265                entry.or_default().push(row_idx);
266            }
267
268            let num_groups = ordered_keys.len();
269
270            // Build output columns
271            let mut output_columns: Vec<arrow_array::ArrayRef> = Vec::new();
272
273            // Key columns: take from first row of each group
274            for &ki in &key_indices {
275                let col = batch.column(ki);
276                let first_indices: Vec<u32> =
277                    ordered_keys.iter().map(|k| groups[k][0] as u32).collect();
278                let idx_array = arrow_array::UInt32Array::from(first_indices);
279                let taken = arrow::compute::take(col.as_ref(), &idx_array, None).map_err(|e| {
280                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
281                })?;
282                output_columns.push(taken);
283            }
284
285            // Fold binding columns: compute aggregates per group
286            for binding in &fold_bindings {
287                let col: Arc<dyn Array> = if binding.kind == FoldAggKind::CountAll {
288                    // CountAll doesn't need an input column — use a dummy
289                    Arc::new(arrow_array::Int64Array::from(vec![0i64; batch.num_rows()]))
290                } else {
291                    Arc::clone(batch.column(binding.input_col_index))
292                };
293                let agg_col = compute_fold_aggregate(
294                    col.as_ref(),
295                    &binding.kind,
296                    &ordered_keys,
297                    &groups,
298                    num_groups,
299                    strict,
300                    epsilon,
301                )?;
302                output_columns.push(agg_col);
303            }
304
305            RecordBatch::try_new(output_schema, output_columns).map_err(arrow_err)
306        };
307
308        Ok(Box::pin(FoldStream {
309            state: FoldStreamState::Running(Box::pin(fut)),
310            schema: Arc::clone(&self.schema),
311            metrics,
312        }))
313    }
314
315    fn metrics(&self) -> Option<MetricsSet> {
316        Some(self.metrics.clone_inner())
317    }
318}
319
// ---------------------------------------------------------------------------
// Aggregate computation
// ---------------------------------------------------------------------------
323
324fn compute_fold_aggregate(
325    col: &dyn Array,
326    kind: &FoldAggKind,
327    ordered_keys: &[Vec<ScalarKey>],
328    groups: &HashMap<Vec<ScalarKey>, Vec<usize>>,
329    num_groups: usize,
330    strict: bool,
331    probability_epsilon: f64,
332) -> DFResult<arrow_array::ArrayRef> {
333    match kind {
334        FoldAggKind::Sum => {
335            let mut builder = Float64Builder::with_capacity(num_groups);
336            for key in ordered_keys {
337                builder.append_option(sum_f64(col, &groups[key]));
338            }
339            Ok(Arc::new(builder.finish()))
340        }
341        FoldAggKind::Count => {
342            let mut builder = Int64Builder::with_capacity(num_groups);
343            for key in ordered_keys {
344                let indices = &groups[key];
345                let count = indices.iter().filter(|&&i| !col.is_null(i)).count();
346                builder.append_value(count as i64);
347            }
348            Ok(Arc::new(builder.finish()))
349        }
350        FoldAggKind::CountAll => {
351            let mut builder = Int64Builder::with_capacity(num_groups);
352            for key in ordered_keys {
353                let indices = &groups[key];
354                builder.append_value(indices.len() as i64);
355            }
356            Ok(Arc::new(builder.finish()))
357        }
358        FoldAggKind::Max => compute_minmax(col, ordered_keys, groups, num_groups, false),
359        FoldAggKind::Min => compute_minmax(col, ordered_keys, groups, num_groups, true),
360        FoldAggKind::Avg => {
361            let mut builder = Float64Builder::with_capacity(num_groups);
362            for key in ordered_keys {
363                let indices = &groups[key];
364                let count = indices.iter().filter(|&&i| !col.is_null(i)).count();
365                let avg = sum_f64(col, indices)
366                    .filter(|_| count > 0)
367                    .map(|s| s / count as f64);
368                builder.append_option(avg);
369            }
370            Ok(Arc::new(builder.finish()))
371        }
372        FoldAggKind::Collect => {
373            let mut builder = LargeBinaryBuilder::with_capacity(num_groups, num_groups * 32);
374            for key in ordered_keys {
375                let values: Vec<uni_common::Value> = groups[key]
376                    .iter()
377                    .filter(|&&i| !col.is_null(i))
378                    .map(|&i| scalar_to_value(col, i))
379                    .collect();
380                let encoded =
381                    uni_common::cypher_value_codec::encode(&uni_common::Value::List(values));
382                builder.append_value(&encoded);
383            }
384            Ok(Arc::new(builder.finish()))
385        }
386        FoldAggKind::Nor => {
387            let mut builder = Float64Builder::with_capacity(num_groups);
388            for key in ordered_keys {
389                let indices = &groups[key];
390                builder.append_option(noisy_or_f64(col, indices, strict)?);
391            }
392            Ok(Arc::new(builder.finish()))
393        }
394        FoldAggKind::Prod => {
395            let mut builder = Float64Builder::with_capacity(num_groups);
396            for key in ordered_keys {
397                builder.append_option(product_f64(col, &groups[key], strict, probability_epsilon)?);
398            }
399            Ok(Arc::new(builder.finish()))
400        }
401    }
402}
403
404fn sum_f64(col: &dyn Array, indices: &[usize]) -> Option<f64> {
405    let mut sum = 0.0;
406    let mut has_value = false;
407    for &i in indices {
408        if col.is_null(i) {
409            continue;
410        }
411        has_value = true;
412        if let Some(arr) = col.as_any().downcast_ref::<Float64Array>() {
413            sum += arr.value(i);
414        } else if let Some(arr) = col.as_any().downcast_ref::<Int64Array>() {
415            sum += arr.value(i) as f64;
416        }
417    }
418    if has_value { Some(sum) } else { None }
419}
420
421/// Noisy-OR: P = 1 − ∏(1 − pᵢ). Inputs clamped to [0, 1] unless strict.
422fn noisy_or_f64(col: &dyn Array, indices: &[usize], strict: bool) -> DFResult<Option<f64>> {
423    let mut complement_product = 1.0;
424    let mut has_value = false;
425    for &i in indices {
426        if col.is_null(i) {
427            continue;
428        }
429        has_value = true;
430        let raw = if let Some(arr) = col.as_any().downcast_ref::<Float64Array>() {
431            arr.value(i)
432        } else if let Some(arr) = col.as_any().downcast_ref::<Int64Array>() {
433            arr.value(i) as f64
434        } else {
435            continue;
436        };
437        if strict && !(0.0..=1.0).contains(&raw) {
438            return Err(datafusion::error::DataFusionError::Execution(format!(
439                "strict_probability_domain: MNOR input {raw} is outside [0, 1]"
440            )));
441        }
442        if !strict && !(0.0..=1.0).contains(&raw) {
443            tracing::warn!(
444                "MNOR input {raw} outside [0,1], clamped to {}",
445                raw.clamp(0.0, 1.0)
446            );
447        }
448        let p = raw.clamp(0.0, 1.0);
449        complement_product *= 1.0 - p;
450    }
451    if has_value {
452        Ok(Some(1.0 - complement_product))
453    } else {
454        Ok(None)
455    }
456}
457
458/// Product: P = ∏ pᵢ. Inputs clamped to [0, 1] unless strict.
459///
460/// Switches to log-space when the running product drops below
461/// `probability_epsilon` to prevent floating-point underflow.
462fn product_f64(
463    col: &dyn Array,
464    indices: &[usize],
465    strict: bool,
466    probability_epsilon: f64,
467) -> DFResult<Option<f64>> {
468    let mut product = 1.0;
469    let mut log_sum = 0.0;
470    let mut use_log = false;
471    let mut has_value = false;
472    for &i in indices {
473        if col.is_null(i) {
474            continue;
475        }
476        has_value = true;
477        let raw = if let Some(arr) = col.as_any().downcast_ref::<Float64Array>() {
478            arr.value(i)
479        } else if let Some(arr) = col.as_any().downcast_ref::<Int64Array>() {
480            arr.value(i) as f64
481        } else {
482            continue;
483        };
484        if strict && !(0.0..=1.0).contains(&raw) {
485            return Err(datafusion::error::DataFusionError::Execution(format!(
486                "strict_probability_domain: MPROD input {raw} is outside [0, 1]"
487            )));
488        }
489        if !strict && !(0.0..=1.0).contains(&raw) {
490            tracing::warn!(
491                "MPROD input {raw} outside [0,1], clamped to {}",
492                raw.clamp(0.0, 1.0)
493            );
494        }
495        let p = raw.clamp(0.0, 1.0);
496        if p == 0.0 {
497            return Ok(Some(0.0));
498        }
499        if use_log {
500            log_sum += p.ln();
501        } else {
502            product *= p;
503            if product < probability_epsilon {
504                // Switch to log-space to prevent underflow
505                log_sum = product.ln();
506                use_log = true;
507            }
508        }
509    }
510    if !has_value {
511        return Ok(None);
512    }
513    if use_log {
514        Ok(Some(log_sum.exp()))
515    } else {
516        Ok(Some(product))
517    }
518}
519
520fn compute_minmax(
521    col: &dyn Array,
522    ordered_keys: &[Vec<ScalarKey>],
523    groups: &HashMap<Vec<ScalarKey>, Vec<usize>>,
524    num_groups: usize,
525    is_min: bool,
526) -> DFResult<arrow_array::ArrayRef> {
527    match col.data_type() {
528        DataType::Int64 => {
529            let arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
530            let mut builder = Int64Builder::with_capacity(num_groups);
531            for key in ordered_keys {
532                let mut result: Option<i64> = None;
533                for &i in &groups[key] {
534                    if !arr.is_null(i) {
535                        let v = arr.value(i);
536                        result = Some(match result {
537                            None => v,
538                            Some(cur) if is_min => cur.min(v),
539                            Some(cur) => cur.max(v),
540                        });
541                    }
542                }
543                builder.append_option(result);
544            }
545            Ok(Arc::new(builder.finish()))
546        }
547        DataType::Float64 => {
548            let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();
549            let mut builder = Float64Builder::with_capacity(num_groups);
550            for key in ordered_keys {
551                let mut result: Option<f64> = None;
552                for &i in &groups[key] {
553                    if !arr.is_null(i) {
554                        let v = arr.value(i);
555                        result = Some(match result {
556                            None => v,
557                            Some(cur) if is_min => cur.min(v),
558                            Some(cur) => cur.max(v),
559                        });
560                    }
561                }
562                builder.append_option(result);
563            }
564            Ok(Arc::new(builder.finish()))
565        }
566        dt => {
567            // Fallback: treat as string comparison.
568            // Use LargeStringBuilder for LargeUtf8 input to match the output schema
569            // (build_output_schema preserves the input type for MAX/MIN).
570            let use_large = matches!(dt, DataType::LargeUtf8);
571            let mut values: Vec<Option<String>> = Vec::with_capacity(num_groups);
572            for key in ordered_keys {
573                let indices = &groups[key];
574                let mut result: Option<String> = None;
575                for &i in indices {
576                    if col.is_null(i) {
577                        continue;
578                    }
579                    let v = format!("{:?}", scalar_to_value(col, i));
580                    result = Some(match result {
581                        None => v,
582                        Some(cur) if is_min && v < cur => v,
583                        Some(cur) if !is_min && v > cur => v,
584                        Some(cur) => cur,
585                    });
586                }
587                values.push(result);
588            }
589            Ok(build_optional_string_array(&values, use_large))
590        }
591    }
592}
593
594fn build_optional_string_array(
595    values: &[Option<String>],
596    use_large: bool,
597) -> arrow_array::ArrayRef {
598    if use_large {
599        let mut builder = arrow_array::builder::LargeStringBuilder::new();
600        for v in values {
601            match v {
602                Some(s) => builder.append_value(s),
603                None => builder.append_null(),
604            }
605        }
606        Arc::new(builder.finish())
607    } else {
608        let mut builder = arrow_array::builder::StringBuilder::new();
609        for v in values {
610            match v {
611                Some(s) => builder.append_value(s),
612                None => builder.append_null(),
613            }
614        }
615        Arc::new(builder.finish())
616    }
617}
618
619fn scalar_to_value(col: &dyn Array, row_idx: usize) -> uni_common::Value {
620    if col.is_null(row_idx) {
621        return uni_common::Value::Null;
622    }
623    match col.data_type() {
624        DataType::Int64 => {
625            let arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
626            uni_common::Value::Int(arr.value(row_idx))
627        }
628        DataType::Float64 => {
629            let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();
630            uni_common::Value::Float(arr.value(row_idx))
631        }
632        DataType::Utf8 => {
633            let arr = col
634                .as_any()
635                .downcast_ref::<arrow_array::StringArray>()
636                .unwrap();
637            uni_common::Value::String(arr.value(row_idx).to_string())
638        }
639        DataType::LargeUtf8 => {
640            let arr = col
641                .as_any()
642                .downcast_ref::<arrow_array::LargeStringArray>()
643                .unwrap();
644            uni_common::Value::String(arr.value(row_idx).to_string())
645        }
646        DataType::Boolean => {
647            let arr = col
648                .as_any()
649                .downcast_ref::<arrow_array::BooleanArray>()
650                .unwrap();
651            uni_common::Value::Bool(arr.value(row_idx))
652        }
653        DataType::LargeBinary => {
654            let arr = col
655                .as_any()
656                .downcast_ref::<arrow_array::LargeBinaryArray>()
657                .unwrap();
658            let bytes = arr.value(row_idx);
659            uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null)
660        }
661        _ => uni_common::Value::Null,
662    }
663}
664
// ---------------------------------------------------------------------------
// Stream implementation
// ---------------------------------------------------------------------------
668
669enum FoldStreamState {
670    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
671    Done,
672}
673
674struct FoldStream {
675    state: FoldStreamState,
676    schema: SchemaRef,
677    metrics: BaselineMetrics,
678}
679
680impl Stream for FoldStream {
681    type Item = DFResult<RecordBatch>;
682
683    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
684        match &mut self.state {
685            FoldStreamState::Running(fut) => match fut.as_mut().poll(cx) {
686                Poll::Ready(Ok(batch)) => {
687                    self.metrics.record_output(batch.num_rows());
688                    self.state = FoldStreamState::Done;
689                    Poll::Ready(Some(Ok(batch)))
690                }
691                Poll::Ready(Err(e)) => {
692                    self.state = FoldStreamState::Done;
693                    Poll::Ready(Some(Err(e)))
694                }
695                Poll::Pending => Poll::Pending,
696            },
697            FoldStreamState::Done => Poll::Ready(None),
698        }
699    }
700}
701
702impl RecordBatchStream for FoldStream {
703    fn schema(&self) -> SchemaRef {
704        Arc::clone(&self.schema)
705    }
706}
707
708#[cfg(test)]
709mod tests {
710    use super::*;
711    use arrow_array::{Float64Array, Int64Array, StringArray};
712    use arrow_schema::{DataType, Field, Schema};
713    use datafusion::physical_plan::memory::MemoryStream;
714    use datafusion::prelude::SessionContext;
715
716    fn make_test_batch(names: Vec<&str>, values: Vec<f64>) -> RecordBatch {
717        let schema = Arc::new(Schema::new(vec![
718            Field::new("name", DataType::Utf8, true),
719            Field::new("value", DataType::Float64, true),
720        ]));
721        RecordBatch::try_new(
722            schema,
723            vec![
724                Arc::new(StringArray::from(
725                    names.into_iter().map(Some).collect::<Vec<_>>(),
726                )),
727                Arc::new(Float64Array::from(values)),
728            ],
729        )
730        .unwrap()
731    }
732
733    fn make_memory_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
734        let schema = batch.schema();
735        Arc::new(TestMemoryExec {
736            batches: vec![batch],
737            schema: schema.clone(),
738            properties: compute_plan_properties(schema),
739        })
740    }
741
742    #[derive(Debug)]
743    struct TestMemoryExec {
744        batches: Vec<RecordBatch>,
745        schema: SchemaRef,
746        properties: PlanProperties,
747    }
748
749    impl DisplayAs for TestMemoryExec {
750        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
751            write!(f, "TestMemoryExec")
752        }
753    }
754
755    impl ExecutionPlan for TestMemoryExec {
756        fn name(&self) -> &str {
757            "TestMemoryExec"
758        }
759        fn as_any(&self) -> &dyn Any {
760            self
761        }
762        fn schema(&self) -> SchemaRef {
763            Arc::clone(&self.schema)
764        }
765        fn properties(&self) -> &PlanProperties {
766            &self.properties
767        }
768        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
769            vec![]
770        }
771        fn with_new_children(
772            self: Arc<Self>,
773            _children: Vec<Arc<dyn ExecutionPlan>>,
774        ) -> DFResult<Arc<dyn ExecutionPlan>> {
775            Ok(self)
776        }
777        fn execute(
778            &self,
779            _partition: usize,
780            _context: Arc<TaskContext>,
781        ) -> DFResult<SendableRecordBatchStream> {
782            Ok(Box::pin(MemoryStream::try_new(
783                self.batches.clone(),
784                Arc::clone(&self.schema),
785                None,
786            )?))
787        }
788    }
789
790    async fn execute_fold(
791        input: Arc<dyn ExecutionPlan>,
792        key_indices: Vec<usize>,
793        fold_bindings: Vec<FoldBinding>,
794    ) -> RecordBatch {
795        let exec = FoldExec::new(input, key_indices, fold_bindings, false, 1e-15);
796        let ctx = SessionContext::new();
797        let task_ctx = ctx.task_ctx();
798        let stream = exec.execute(0, task_ctx).unwrap();
799        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream)
800            .await
801            .unwrap();
802        if batches.is_empty() {
803            RecordBatch::new_empty(exec.schema())
804        } else {
805            arrow::compute::concat_batches(&exec.schema(), &batches).unwrap()
806        }
807    }
808
809    #[tokio::test]
810    async fn test_sum_single_group() {
811        let batch = make_test_batch(vec!["a", "a", "a"], vec![1.0, 2.0, 3.0]);
812        let input = make_memory_exec(batch);
813        let result = execute_fold(
814            input,
815            vec![0],
816            vec![FoldBinding {
817                output_name: "total".to_string(),
818                kind: FoldAggKind::Sum,
819                input_col_index: 1,
820            }],
821        )
822        .await;
823
824        assert_eq!(result.num_rows(), 1);
825        let totals = result
826            .column(1)
827            .as_any()
828            .downcast_ref::<Float64Array>()
829            .unwrap();
830        assert!((totals.value(0) - 6.0).abs() < f64::EPSILON);
831    }
832
833    #[tokio::test]
834    async fn test_count_non_null() {
835        let schema = Arc::new(Schema::new(vec![
836            Field::new("name", DataType::Utf8, true),
837            Field::new("value", DataType::Float64, true),
838        ]));
839        let batch = RecordBatch::try_new(
840            schema,
841            vec![
842                Arc::new(StringArray::from(vec![Some("a"), Some("a"), Some("a")])),
843                Arc::new(Float64Array::from(vec![Some(1.0), None, Some(3.0)])),
844            ],
845        )
846        .unwrap();
847        let input = make_memory_exec(batch);
848        let result = execute_fold(
849            input,
850            vec![0],
851            vec![FoldBinding {
852                output_name: "cnt".to_string(),
853                kind: FoldAggKind::Count,
854                input_col_index: 1,
855            }],
856        )
857        .await;
858
859        assert_eq!(result.num_rows(), 1);
860        let counts = result
861            .column(1)
862            .as_any()
863            .downcast_ref::<Int64Array>()
864            .unwrap();
865        assert_eq!(counts.value(0), 2); // null not counted
866    }
867
868    #[tokio::test]
869    async fn test_max_min() {
870        let batch = make_test_batch(vec!["a", "a", "a"], vec![3.0, 1.0, 5.0]);
871        let input_max = make_memory_exec(batch.clone());
872        let input_min = make_memory_exec(batch);
873
874        let result_max = execute_fold(
875            input_max,
876            vec![0],
877            vec![FoldBinding {
878                output_name: "mx".to_string(),
879                kind: FoldAggKind::Max,
880                input_col_index: 1,
881            }],
882        )
883        .await;
884        let result_min = execute_fold(
885            input_min,
886            vec![0],
887            vec![FoldBinding {
888                output_name: "mn".to_string(),
889                kind: FoldAggKind::Min,
890                input_col_index: 1,
891            }],
892        )
893        .await;
894
895        let max_vals = result_max
896            .column(1)
897            .as_any()
898            .downcast_ref::<Float64Array>()
899            .unwrap();
900        assert_eq!(max_vals.value(0), 5.0);
901
902        let min_vals = result_min
903            .column(1)
904            .as_any()
905            .downcast_ref::<Float64Array>()
906            .unwrap();
907        assert_eq!(min_vals.value(0), 1.0);
908    }
909
910    #[tokio::test]
911    async fn test_avg() {
912        let batch = make_test_batch(vec!["a", "a", "a", "a"], vec![2.0, 4.0, 6.0, 8.0]);
913        let input = make_memory_exec(batch);
914        let result = execute_fold(
915            input,
916            vec![0],
917            vec![FoldBinding {
918                output_name: "average".to_string(),
919                kind: FoldAggKind::Avg,
920                input_col_index: 1,
921            }],
922        )
923        .await;
924
925        assert_eq!(result.num_rows(), 1);
926        let avgs = result
927            .column(1)
928            .as_any()
929            .downcast_ref::<Float64Array>()
930            .unwrap();
931        assert!((avgs.value(0) - 5.0).abs() < f64::EPSILON);
932    }
933
934    #[tokio::test]
935    async fn test_multiple_groups() {
936        let batch = make_test_batch(
937            vec!["a", "a", "b", "b", "b"],
938            vec![1.0, 2.0, 10.0, 20.0, 30.0],
939        );
940        let input = make_memory_exec(batch);
941        let result = execute_fold(
942            input,
943            vec![0],
944            vec![FoldBinding {
945                output_name: "total".to_string(),
946                kind: FoldAggKind::Sum,
947                input_col_index: 1,
948            }],
949        )
950        .await;
951
952        assert_eq!(result.num_rows(), 2);
953        let names = result
954            .column(0)
955            .as_any()
956            .downcast_ref::<StringArray>()
957            .unwrap();
958        let totals = result
959            .column(1)
960            .as_any()
961            .downcast_ref::<Float64Array>()
962            .unwrap();
963
964        for i in 0..2 {
965            match names.value(i) {
966                "a" => assert!((totals.value(i) - 3.0).abs() < f64::EPSILON),
967                "b" => assert!((totals.value(i) - 60.0).abs() < f64::EPSILON),
968                _ => panic!("unexpected name"),
969            }
970        }
971    }
972
973    #[tokio::test]
974    async fn test_empty_input() {
975        let schema = Arc::new(Schema::new(vec![
976            Field::new("name", DataType::Utf8, true),
977            Field::new("value", DataType::Float64, true),
978        ]));
979        let batch = RecordBatch::new_empty(schema);
980        let input = make_memory_exec(batch);
981        let result = execute_fold(
982            input,
983            vec![0],
984            vec![FoldBinding {
985                output_name: "total".to_string(),
986                kind: FoldAggKind::Sum,
987                input_col_index: 1,
988            }],
989        )
990        .await;
991
992        assert_eq!(result.num_rows(), 0);
993    }
994
995    #[tokio::test]
996    async fn test_multiple_bindings() {
997        let batch = make_test_batch(vec!["a", "a", "a"], vec![1.0, 2.0, 3.0]);
998        let input = make_memory_exec(batch);
999        let result = execute_fold(
1000            input,
1001            vec![0],
1002            vec![
1003                FoldBinding {
1004                    output_name: "total".to_string(),
1005                    kind: FoldAggKind::Sum,
1006                    input_col_index: 1,
1007                },
1008                FoldBinding {
1009                    output_name: "cnt".to_string(),
1010                    kind: FoldAggKind::Count,
1011                    input_col_index: 1,
1012                },
1013                FoldBinding {
1014                    output_name: "mx".to_string(),
1015                    kind: FoldAggKind::Max,
1016                    input_col_index: 1,
1017                },
1018            ],
1019        )
1020        .await;
1021
1022        assert_eq!(result.num_rows(), 1);
1023        assert_eq!(result.num_columns(), 4); // name + total + cnt + mx
1024
1025        let totals = result
1026            .column(1)
1027            .as_any()
1028            .downcast_ref::<Float64Array>()
1029            .unwrap();
1030        assert!((totals.value(0) - 6.0).abs() < f64::EPSILON);
1031
1032        let counts = result
1033            .column(2)
1034            .as_any()
1035            .downcast_ref::<Int64Array>()
1036            .unwrap();
1037        assert_eq!(counts.value(0), 3);
1038
1039        let maxes = result
1040            .column(3)
1041            .as_any()
1042            .downcast_ref::<Float64Array>()
1043            .unwrap();
1044        assert_eq!(maxes.value(0), 3.0);
1045    }
1046
1047    // ── MNOR tests ───────────────────────────────────────────────────────
1048
1049    #[tokio::test]
1050    async fn test_nor_single_group() {
1051        // MNOR({0.3, 0.5}) = 1 - (1-0.3)*(1-0.5) = 1 - 0.7*0.5 = 1 - 0.35 = 0.65
1052        let batch = make_test_batch(vec!["a", "a"], vec![0.3, 0.5]);
1053        let input = make_memory_exec(batch);
1054        let result = execute_fold(
1055            input,
1056            vec![0],
1057            vec![FoldBinding {
1058                output_name: "prob".to_string(),
1059                kind: FoldAggKind::Nor,
1060                input_col_index: 1,
1061            }],
1062        )
1063        .await;
1064
1065        assert_eq!(result.num_rows(), 1);
1066        let vals = result
1067            .column(1)
1068            .as_any()
1069            .downcast_ref::<Float64Array>()
1070            .unwrap();
1071        assert!((vals.value(0) - 0.65).abs() < 1e-10);
1072    }
1073
1074    #[tokio::test]
1075    async fn test_nor_identity() {
1076        // MNOR({0.0, 0.0}) = 1 - (1-0)*(1-0) = 1 - 1 = 0.0
1077        let batch = make_test_batch(vec!["a", "a"], vec![0.0, 0.0]);
1078        let input = make_memory_exec(batch);
1079        let result = execute_fold(
1080            input,
1081            vec![0],
1082            vec![FoldBinding {
1083                output_name: "prob".to_string(),
1084                kind: FoldAggKind::Nor,
1085                input_col_index: 1,
1086            }],
1087        )
1088        .await;
1089
1090        let vals = result
1091            .column(1)
1092            .as_any()
1093            .downcast_ref::<Float64Array>()
1094            .unwrap();
1095        assert!((vals.value(0) - 0.0).abs() < 1e-10);
1096    }
1097
1098    #[tokio::test]
1099    async fn test_nor_clamping() {
1100        // Out-of-range values should be clamped to [0, 1]
1101        let batch = make_test_batch(vec!["a", "a"], vec![-0.5, 1.5]);
1102        let input = make_memory_exec(batch);
1103        let result = execute_fold(
1104            input,
1105            vec![0],
1106            vec![FoldBinding {
1107                output_name: "prob".to_string(),
1108                kind: FoldAggKind::Nor,
1109                input_col_index: 1,
1110            }],
1111        )
1112        .await;
1113
1114        let vals = result
1115            .column(1)
1116            .as_any()
1117            .downcast_ref::<Float64Array>()
1118            .unwrap();
1119        // Clamped to (0.0, 1.0): MNOR = 1 - (1-0)*(1-1) = 1 - 1*0 = 1.0
1120        assert!((vals.value(0) - 1.0).abs() < 1e-10);
1121    }
1122
1123    #[tokio::test]
1124    async fn test_nor_multiple_groups() {
1125        let batch = make_test_batch(vec!["a", "a", "b", "b"], vec![0.3, 0.5, 0.1, 0.2]);
1126        let input = make_memory_exec(batch);
1127        let result = execute_fold(
1128            input,
1129            vec![0],
1130            vec![FoldBinding {
1131                output_name: "prob".to_string(),
1132                kind: FoldAggKind::Nor,
1133                input_col_index: 1,
1134            }],
1135        )
1136        .await;
1137
1138        assert_eq!(result.num_rows(), 2);
1139        let names = result
1140            .column(0)
1141            .as_any()
1142            .downcast_ref::<StringArray>()
1143            .unwrap();
1144        let vals = result
1145            .column(1)
1146            .as_any()
1147            .downcast_ref::<Float64Array>()
1148            .unwrap();
1149
1150        for i in 0..2 {
1151            match names.value(i) {
1152                // MNOR({0.3, 0.5}) = 0.65
1153                "a" => assert!((vals.value(i) - 0.65).abs() < 1e-10),
1154                // MNOR({0.1, 0.2}) = 1 - 0.9*0.8 = 1 - 0.72 = 0.28
1155                "b" => assert!((vals.value(i) - 0.28).abs() < 1e-10),
1156                _ => panic!("unexpected name"),
1157            }
1158        }
1159    }
1160
1161    // ── MPROD tests ──────────────────────────────────────────────────────
1162
1163    #[tokio::test]
1164    async fn test_prod_single_group() {
1165        // MPROD({0.6, 0.8}) = 0.48
1166        let batch = make_test_batch(vec!["a", "a"], vec![0.6, 0.8]);
1167        let input = make_memory_exec(batch);
1168        let result = execute_fold(
1169            input,
1170            vec![0],
1171            vec![FoldBinding {
1172                output_name: "prob".to_string(),
1173                kind: FoldAggKind::Prod,
1174                input_col_index: 1,
1175            }],
1176        )
1177        .await;
1178
1179        assert_eq!(result.num_rows(), 1);
1180        let vals = result
1181            .column(1)
1182            .as_any()
1183            .downcast_ref::<Float64Array>()
1184            .unwrap();
1185        assert!((vals.value(0) - 0.48).abs() < 1e-10);
1186    }
1187
1188    #[tokio::test]
1189    async fn test_prod_identity() {
1190        // MPROD({1.0, 1.0}) = 1.0
1191        let batch = make_test_batch(vec!["a", "a"], vec![1.0, 1.0]);
1192        let input = make_memory_exec(batch);
1193        let result = execute_fold(
1194            input,
1195            vec![0],
1196            vec![FoldBinding {
1197                output_name: "prob".to_string(),
1198                kind: FoldAggKind::Prod,
1199                input_col_index: 1,
1200            }],
1201        )
1202        .await;
1203
1204        let vals = result
1205            .column(1)
1206            .as_any()
1207            .downcast_ref::<Float64Array>()
1208            .unwrap();
1209        assert!((vals.value(0) - 1.0).abs() < 1e-10);
1210    }
1211
1212    #[tokio::test]
1213    async fn test_prod_zero_absorbing() {
1214        // MPROD with 0.0 = 0.0 (zero is absorbing element)
1215        let batch = make_test_batch(vec!["a", "a", "a"], vec![0.5, 0.0, 0.8]);
1216        let input = make_memory_exec(batch);
1217        let result = execute_fold(
1218            input,
1219            vec![0],
1220            vec![FoldBinding {
1221                output_name: "prob".to_string(),
1222                kind: FoldAggKind::Prod,
1223                input_col_index: 1,
1224            }],
1225        )
1226        .await;
1227
1228        let vals = result
1229            .column(1)
1230            .as_any()
1231            .downcast_ref::<Float64Array>()
1232            .unwrap();
1233        assert!((vals.value(0) - 0.0).abs() < 1e-10);
1234    }
1235
1236    #[tokio::test]
1237    async fn test_prod_underflow_protection() {
1238        // 50 × 0.5 ≈ 8.88e-16, should not be exactly 0 thanks to log-space
1239        let names: Vec<&str> = vec!["a"; 50];
1240        let values: Vec<f64> = vec![0.5; 50];
1241        let batch = make_test_batch(names, values);
1242        let input = make_memory_exec(batch);
1243        let result = execute_fold(
1244            input,
1245            vec![0],
1246            vec![FoldBinding {
1247                output_name: "prob".to_string(),
1248                kind: FoldAggKind::Prod,
1249                input_col_index: 1,
1250            }],
1251        )
1252        .await;
1253
1254        let vals = result
1255            .column(1)
1256            .as_any()
1257            .downcast_ref::<Float64Array>()
1258            .unwrap();
1259        let expected = 0.5_f64.powi(50); // ≈ 8.88e-16
1260        assert!(vals.value(0) > 0.0, "should not underflow to zero");
1261        assert!(
1262            (vals.value(0) - expected).abs() / expected < 1e-6,
1263            "result {} should be close to expected {}",
1264            vals.value(0),
1265            expected
1266        );
1267    }
1268
1269    // ── MNOR/MPROD mathematical correctness tests ───────────────────────
1270
1271    fn make_nullable_test_batch(names: Vec<&str>, values: Vec<Option<f64>>) -> RecordBatch {
1272        let schema = Arc::new(Schema::new(vec![
1273            Field::new("name", DataType::Utf8, true),
1274            Field::new("value", DataType::Float64, true),
1275        ]));
1276        RecordBatch::try_new(
1277            schema,
1278            vec![
1279                Arc::new(StringArray::from(
1280                    names.into_iter().map(Some).collect::<Vec<_>>(),
1281                )),
1282                Arc::new(Float64Array::from(values)),
1283            ],
1284        )
1285        .unwrap()
1286    }
1287
1288    #[tokio::test]
1289    async fn test_nor_single_element() {
1290        // MNOR({0.7}) = 0.7 (n=1 identity)
1291        let batch = make_test_batch(vec!["a"], vec![0.7]);
1292        let input = make_memory_exec(batch);
1293        let result = execute_fold(
1294            input,
1295            vec![0],
1296            vec![FoldBinding {
1297                output_name: "prob".to_string(),
1298                kind: FoldAggKind::Nor,
1299                input_col_index: 1,
1300            }],
1301        )
1302        .await;
1303        let vals = result
1304            .column(1)
1305            .as_any()
1306            .downcast_ref::<Float64Array>()
1307            .unwrap();
1308        assert!((vals.value(0) - 0.7).abs() < 1e-10);
1309    }
1310
1311    #[tokio::test]
1312    async fn test_prod_single_element() {
1313        // MPROD({0.7}) = 0.7 (n=1 identity)
1314        let batch = make_test_batch(vec!["a"], vec![0.7]);
1315        let input = make_memory_exec(batch);
1316        let result = execute_fold(
1317            input,
1318            vec![0],
1319            vec![FoldBinding {
1320                output_name: "prob".to_string(),
1321                kind: FoldAggKind::Prod,
1322                input_col_index: 1,
1323            }],
1324        )
1325        .await;
1326        let vals = result
1327            .column(1)
1328            .as_any()
1329            .downcast_ref::<Float64Array>()
1330            .unwrap();
1331        assert!((vals.value(0) - 0.7).abs() < 1e-10);
1332    }
1333
1334    #[tokio::test]
1335    async fn test_nor_three_elements() {
1336        // MNOR({0.3, 0.4, 0.5}) = 1 - (0.7)(0.6)(0.5) = 0.79
1337        let batch = make_test_batch(vec!["a", "a", "a"], vec![0.3, 0.4, 0.5]);
1338        let input = make_memory_exec(batch);
1339        let result = execute_fold(
1340            input,
1341            vec![0],
1342            vec![FoldBinding {
1343                output_name: "prob".to_string(),
1344                kind: FoldAggKind::Nor,
1345                input_col_index: 1,
1346            }],
1347        )
1348        .await;
1349        let vals = result
1350            .column(1)
1351            .as_any()
1352            .downcast_ref::<Float64Array>()
1353            .unwrap();
1354        assert!((vals.value(0) - 0.79).abs() < 1e-10);
1355    }
1356
1357    #[tokio::test]
1358    async fn test_nor_four_elements_spec_example() {
1359        // Spec §4.5: MNOR({0.72, 0.54, 0.56, 0.42}) = 1 - (0.28)(0.46)(0.44)(0.58) = 0.96713024
1360        let batch = make_test_batch(vec!["a", "a", "a", "a"], vec![0.72, 0.54, 0.56, 0.42]);
1361        let input = make_memory_exec(batch);
1362        let result = execute_fold(
1363            input,
1364            vec![0],
1365            vec![FoldBinding {
1366                output_name: "prob".to_string(),
1367                kind: FoldAggKind::Nor,
1368                input_col_index: 1,
1369            }],
1370        )
1371        .await;
1372        let vals = result
1373            .column(1)
1374            .as_any()
1375            .downcast_ref::<Float64Array>()
1376            .unwrap();
1377        assert!(
1378            (vals.value(0) - 0.96713024).abs() < 1e-10,
1379            "expected 0.96713024, got {}",
1380            vals.value(0)
1381        );
1382    }
1383
1384    #[tokio::test]
1385    async fn test_prod_three_elements() {
1386        // MPROD({0.5, 0.5, 0.5}) = 0.125
1387        let batch = make_test_batch(vec!["a", "a", "a"], vec![0.5, 0.5, 0.5]);
1388        let input = make_memory_exec(batch);
1389        let result = execute_fold(
1390            input,
1391            vec![0],
1392            vec![FoldBinding {
1393                output_name: "prob".to_string(),
1394                kind: FoldAggKind::Prod,
1395                input_col_index: 1,
1396            }],
1397        )
1398        .await;
1399        let vals = result
1400            .column(1)
1401            .as_any()
1402            .downcast_ref::<Float64Array>()
1403            .unwrap();
1404        assert!((vals.value(0) - 0.125).abs() < 1e-10);
1405    }
1406
1407    #[tokio::test]
1408    async fn test_nor_absorbing_element() {
1409        // p=1.0 absorbs: MNOR({0.3, 1.0}) = 1.0
1410        let batch = make_test_batch(vec!["a", "a"], vec![0.3, 1.0]);
1411        let input = make_memory_exec(batch);
1412        let result = execute_fold(
1413            input,
1414            vec![0],
1415            vec![FoldBinding {
1416                output_name: "prob".to_string(),
1417                kind: FoldAggKind::Nor,
1418                input_col_index: 1,
1419            }],
1420        )
1421        .await;
1422        let vals = result
1423            .column(1)
1424            .as_any()
1425            .downcast_ref::<Float64Array>()
1426            .unwrap();
1427        assert!((vals.value(0) - 1.0).abs() < 1e-10);
1428    }
1429
1430    #[tokio::test]
1431    async fn test_prod_clamping() {
1432        // Out-of-range 2.0 clamped to 1.0: MPROD({2.0, 0.5}) = 1.0 * 0.5 = 0.5
1433        let batch = make_test_batch(vec!["a", "a"], vec![2.0, 0.5]);
1434        let input = make_memory_exec(batch);
1435        let result = execute_fold(
1436            input,
1437            vec![0],
1438            vec![FoldBinding {
1439                output_name: "prob".to_string(),
1440                kind: FoldAggKind::Prod,
1441                input_col_index: 1,
1442            }],
1443        )
1444        .await;
1445        let vals = result
1446            .column(1)
1447            .as_any()
1448            .downcast_ref::<Float64Array>()
1449            .unwrap();
1450        assert!((vals.value(0) - 0.5).abs() < 1e-10);
1451    }
1452
1453    #[tokio::test]
1454    async fn test_prod_multiple_groups() {
1455        // a: MPROD({0.6, 0.8}) = 0.48, b: MPROD({0.5, 0.5}) = 0.25
1456        let batch = make_test_batch(vec!["a", "a", "b", "b"], vec![0.6, 0.8, 0.5, 0.5]);
1457        let input = make_memory_exec(batch);
1458        let result = execute_fold(
1459            input,
1460            vec![0],
1461            vec![FoldBinding {
1462                output_name: "prob".to_string(),
1463                kind: FoldAggKind::Prod,
1464                input_col_index: 1,
1465            }],
1466        )
1467        .await;
1468
1469        assert_eq!(result.num_rows(), 2);
1470        let names = result
1471            .column(0)
1472            .as_any()
1473            .downcast_ref::<StringArray>()
1474            .unwrap();
1475        let vals = result
1476            .column(1)
1477            .as_any()
1478            .downcast_ref::<Float64Array>()
1479            .unwrap();
1480        for i in 0..2 {
1481            match names.value(i) {
1482                "a" => assert!((vals.value(i) - 0.48).abs() < 1e-10),
1483                "b" => assert!((vals.value(i) - 0.25).abs() < 1e-10),
1484                _ => panic!("unexpected group name"),
1485            }
1486        }
1487    }
1488
1489    #[tokio::test]
1490    async fn test_nor_commutativity() {
1491        // Order independence: MNOR({0.2, 0.5, 0.8}) = MNOR({0.8, 0.5, 0.2}) = 0.92
1492        let fwd = make_test_batch(vec!["a", "a", "a"], vec![0.2, 0.5, 0.8]);
1493        let rev = make_test_batch(vec!["a", "a", "a"], vec![0.8, 0.5, 0.2]);
1494        let binding = vec![FoldBinding {
1495            output_name: "prob".to_string(),
1496            kind: FoldAggKind::Nor,
1497            input_col_index: 1,
1498        }];
1499        let r1 = execute_fold(make_memory_exec(fwd), vec![0], binding.clone()).await;
1500        let r2 = execute_fold(make_memory_exec(rev), vec![0], binding).await;
1501        let v1 = r1
1502            .column(1)
1503            .as_any()
1504            .downcast_ref::<Float64Array>()
1505            .unwrap()
1506            .value(0);
1507        let v2 = r2
1508            .column(1)
1509            .as_any()
1510            .downcast_ref::<Float64Array>()
1511            .unwrap()
1512            .value(0);
1513        assert!((v1 - 0.92).abs() < 1e-10);
1514        assert!((v2 - 0.92).abs() < 1e-10);
1515        assert!((v1 - v2).abs() < 1e-15, "commutativity violated");
1516    }
1517
1518    #[tokio::test]
1519    async fn test_prod_commutativity() {
1520        // Order independence: MPROD({0.5, 0.25}) = MPROD({0.25, 0.5}) = 0.125
1521        let fwd = make_test_batch(vec!["a", "a"], vec![0.5, 0.25]);
1522        let rev = make_test_batch(vec!["a", "a"], vec![0.25, 0.5]);
1523        let binding = vec![FoldBinding {
1524            output_name: "prob".to_string(),
1525            kind: FoldAggKind::Prod,
1526            input_col_index: 1,
1527        }];
1528        let r1 = execute_fold(make_memory_exec(fwd), vec![0], binding.clone()).await;
1529        let r2 = execute_fold(make_memory_exec(rev), vec![0], binding).await;
1530        let v1 = r1
1531            .column(1)
1532            .as_any()
1533            .downcast_ref::<Float64Array>()
1534            .unwrap()
1535            .value(0);
1536        let v2 = r2
1537            .column(1)
1538            .as_any()
1539            .downcast_ref::<Float64Array>()
1540            .unwrap()
1541            .value(0);
1542        assert!((v1 - 0.125).abs() < 1e-10);
1543        assert!((v2 - 0.125).abs() < 1e-10);
1544        assert!((v1 - v2).abs() < 1e-15, "commutativity violated");
1545    }
1546
1547    #[tokio::test]
1548    async fn test_nor_boundary_near_zero() {
1549        // Precision near 0: MNOR({0.001, 0.002}) = 1 - (0.999)(0.998) = 0.002998
1550        let batch = make_test_batch(vec!["a", "a"], vec![0.001, 0.002]);
1551        let input = make_memory_exec(batch);
1552        let result = execute_fold(
1553            input,
1554            vec![0],
1555            vec![FoldBinding {
1556                output_name: "prob".to_string(),
1557                kind: FoldAggKind::Nor,
1558                input_col_index: 1,
1559            }],
1560        )
1561        .await;
1562        let vals = result
1563            .column(1)
1564            .as_any()
1565            .downcast_ref::<Float64Array>()
1566            .unwrap();
1567        let expected = 1.0 - 0.999 * 0.998;
1568        assert!(
1569            (vals.value(0) - expected).abs() < 1e-10,
1570            "expected {}, got {}",
1571            expected,
1572            vals.value(0)
1573        );
1574    }
1575
1576    #[tokio::test]
1577    async fn test_nor_boundary_near_one() {
1578        // Precision near 1: MNOR({0.999, 0.998}) = 1 - (0.001)(0.002) = 0.999998
1579        let batch = make_test_batch(vec!["a", "a"], vec![0.999, 0.998]);
1580        let input = make_memory_exec(batch);
1581        let result = execute_fold(
1582            input,
1583            vec![0],
1584            vec![FoldBinding {
1585                output_name: "prob".to_string(),
1586                kind: FoldAggKind::Nor,
1587                input_col_index: 1,
1588            }],
1589        )
1590        .await;
1591        let vals = result
1592            .column(1)
1593            .as_any()
1594            .downcast_ref::<Float64Array>()
1595            .unwrap();
1596        let expected = 1.0 - 0.001 * 0.002;
1597        assert!(
1598            (vals.value(0) - expected).abs() < 1e-10,
1599            "expected {}, got {}",
1600            expected,
1601            vals.value(0)
1602        );
1603    }
1604
1605    #[tokio::test]
1606    async fn test_prod_boundary_near_zero() {
1607        // Precision near 0: MPROD({0.001, 0.002}) = 2e-6
1608        let batch = make_test_batch(vec!["a", "a"], vec![0.001, 0.002]);
1609        let input = make_memory_exec(batch);
1610        let result = execute_fold(
1611            input,
1612            vec![0],
1613            vec![FoldBinding {
1614                output_name: "prob".to_string(),
1615                kind: FoldAggKind::Prod,
1616                input_col_index: 1,
1617            }],
1618        )
1619        .await;
1620        let vals = result
1621            .column(1)
1622            .as_any()
1623            .downcast_ref::<Float64Array>()
1624            .unwrap();
1625        assert!(
1626            (vals.value(0) - 2e-6).abs() < 1e-15,
1627            "expected 2e-6, got {}",
1628            vals.value(0)
1629        );
1630    }
1631
1632    #[tokio::test]
1633    async fn test_nor_empty_input() {
1634        // Empty input → 0 rows output
1635        let schema = Arc::new(Schema::new(vec![
1636            Field::new("name", DataType::Utf8, true),
1637            Field::new("value", DataType::Float64, true),
1638        ]));
1639        let batch = RecordBatch::new_empty(schema);
1640        let input = make_memory_exec(batch);
1641        let result = execute_fold(
1642            input,
1643            vec![0],
1644            vec![FoldBinding {
1645                output_name: "prob".to_string(),
1646                kind: FoldAggKind::Nor,
1647                input_col_index: 1,
1648            }],
1649        )
1650        .await;
1651        assert_eq!(result.num_rows(), 0);
1652    }
1653
1654    #[tokio::test]
1655    async fn test_nor_nan_handling() {
1656        // NaN propagates through noisy-OR
1657        let batch = make_test_batch(vec!["a", "a"], vec![0.3, f64::NAN]);
1658        let input = make_memory_exec(batch);
1659        let result = execute_fold(
1660            input,
1661            vec![0],
1662            vec![FoldBinding {
1663                output_name: "prob".to_string(),
1664                kind: FoldAggKind::Nor,
1665                input_col_index: 1,
1666            }],
1667        )
1668        .await;
1669        let vals = result
1670            .column(1)
1671            .as_any()
1672            .downcast_ref::<Float64Array>()
1673            .unwrap();
1674        assert!(vals.value(0).is_nan(), "NaN should propagate through MNOR");
1675    }
1676
1677    #[tokio::test]
1678    async fn test_prod_nan_handling() {
1679        // NaN propagates through product
1680        let batch = make_test_batch(vec!["a", "a"], vec![0.5, f64::NAN]);
1681        let input = make_memory_exec(batch);
1682        let result = execute_fold(
1683            input,
1684            vec![0],
1685            vec![FoldBinding {
1686                output_name: "prob".to_string(),
1687                kind: FoldAggKind::Prod,
1688                input_col_index: 1,
1689            }],
1690        )
1691        .await;
1692        let vals = result
1693            .column(1)
1694            .as_any()
1695            .downcast_ref::<Float64Array>()
1696            .unwrap();
1697        assert!(vals.value(0).is_nan(), "NaN should propagate through MPROD");
1698    }
1699
1700    #[tokio::test]
1701    async fn test_prod_infinity_handling() {
1702        // +∞ clamped to 1.0: MPROD({0.5, ∞}) = 0.5 * 1.0 = 0.5
1703        let batch = make_test_batch(vec!["a", "a"], vec![0.5, f64::INFINITY]);
1704        let input = make_memory_exec(batch);
1705        let result = execute_fold(
1706            input,
1707            vec![0],
1708            vec![FoldBinding {
1709                output_name: "prob".to_string(),
1710                kind: FoldAggKind::Prod,
1711                input_col_index: 1,
1712            }],
1713        )
1714        .await;
1715        let vals = result
1716            .column(1)
1717            .as_any()
1718            .downcast_ref::<Float64Array>()
1719            .unwrap();
1720        assert!((vals.value(0) - 0.5).abs() < 1e-10);
1721    }
1722
1723    #[tokio::test]
1724    async fn test_nor_infinity_handling() {
1725        // +∞ clamped to 1.0, which absorbs: MNOR({0.3, ∞}) = 1.0
1726        let batch = make_test_batch(vec!["a", "a"], vec![0.3, f64::INFINITY]);
1727        let input = make_memory_exec(batch);
1728        let result = execute_fold(
1729            input,
1730            vec![0],
1731            vec![FoldBinding {
1732                output_name: "prob".to_string(),
1733                kind: FoldAggKind::Nor,
1734                input_col_index: 1,
1735            }],
1736        )
1737        .await;
1738        let vals = result
1739            .column(1)
1740            .as_any()
1741            .downcast_ref::<Float64Array>()
1742            .unwrap();
1743        assert!((vals.value(0) - 1.0).abs() < 1e-10);
1744    }
1745
1746    #[tokio::test]
1747    async fn test_nor_all_null_values() {
1748        // All-null input → null output
1749        let batch = make_nullable_test_batch(vec!["a", "a"], vec![None, None]);
1750        let input = make_memory_exec(batch);
1751        let result = execute_fold(
1752            input,
1753            vec![0],
1754            vec![FoldBinding {
1755                output_name: "prob".to_string(),
1756                kind: FoldAggKind::Nor,
1757                input_col_index: 1,
1758            }],
1759        )
1760        .await;
1761        assert_eq!(result.num_rows(), 1);
1762        let vals = result
1763            .column(1)
1764            .as_any()
1765            .downcast_ref::<Float64Array>()
1766            .unwrap();
1767        assert!(vals.is_null(0), "all-null MNOR should produce null");
1768    }
1769
1770    #[tokio::test]
1771    async fn test_prod_all_null_values() {
1772        // All-null input → null output
1773        let batch = make_nullable_test_batch(vec!["a", "a"], vec![None, None]);
1774        let input = make_memory_exec(batch);
1775        let result = execute_fold(
1776            input,
1777            vec![0],
1778            vec![FoldBinding {
1779                output_name: "prob".to_string(),
1780                kind: FoldAggKind::Prod,
1781                input_col_index: 1,
1782            }],
1783        )
1784        .await;
1785        assert_eq!(result.num_rows(), 1);
1786        let vals = result
1787            .column(1)
1788            .as_any()
1789            .downcast_ref::<Float64Array>()
1790            .unwrap();
1791        assert!(vals.is_null(0), "all-null MPROD should produce null");
1792    }
1793
1794    #[tokio::test]
1795    async fn test_nor_mixed_null_values() {
1796        // Nulls skipped: MNOR({0.3, null, 0.5}) = 1 - (0.7)(0.5) = 0.65
1797        let batch = make_nullable_test_batch(vec!["a", "a", "a"], vec![Some(0.3), None, Some(0.5)]);
1798        let input = make_memory_exec(batch);
1799        let result = execute_fold(
1800            input,
1801            vec![0],
1802            vec![FoldBinding {
1803                output_name: "prob".to_string(),
1804                kind: FoldAggKind::Nor,
1805                input_col_index: 1,
1806            }],
1807        )
1808        .await;
1809        let vals = result
1810            .column(1)
1811            .as_any()
1812            .downcast_ref::<Float64Array>()
1813            .unwrap();
1814        assert!((vals.value(0) - 0.65).abs() < 1e-10);
1815    }
1816
1817    #[tokio::test]
1818    async fn test_prod_mixed_null_values() {
1819        // Nulls skipped: MPROD({0.6, null, 0.8}) = 0.6 * 0.8 = 0.48
1820        let batch = make_nullable_test_batch(vec!["a", "a", "a"], vec![Some(0.6), None, Some(0.8)]);
1821        let input = make_memory_exec(batch);
1822        let result = execute_fold(
1823            input,
1824            vec![0],
1825            vec![FoldBinding {
1826                output_name: "prob".to_string(),
1827                kind: FoldAggKind::Prod,
1828                input_col_index: 1,
1829            }],
1830        )
1831        .await;
1832        let vals = result
1833            .column(1)
1834            .as_any()
1835            .downcast_ref::<Float64Array>()
1836            .unwrap();
1837        assert!((vals.value(0) - 0.48).abs() < 1e-10);
1838    }
1839
1840    #[tokio::test]
1841    async fn test_nor_many_small_values() {
1842        // Large accumulation: 20 × 0.1 → 1 - 0.9^20 ≈ 0.8784
1843        let names: Vec<&str> = vec!["a"; 20];
1844        let values: Vec<f64> = vec![0.1; 20];
1845        let batch = make_test_batch(names, values);
1846        let input = make_memory_exec(batch);
1847        let result = execute_fold(
1848            input,
1849            vec![0],
1850            vec![FoldBinding {
1851                output_name: "prob".to_string(),
1852                kind: FoldAggKind::Nor,
1853                input_col_index: 1,
1854            }],
1855        )
1856        .await;
1857        let vals = result
1858            .column(1)
1859            .as_any()
1860            .downcast_ref::<Float64Array>()
1861            .unwrap();
1862        let expected = 1.0 - 0.9_f64.powi(20);
1863        assert!(
1864            (vals.value(0) - expected).abs() < 1e-10,
1865            "expected {}, got {}",
1866            expected,
1867            vals.value(0)
1868        );
1869    }
1870
1871    // ── FoldAggKind classification tests (Phase 1) ────────────────────────
1872
1873    #[test]
1874    fn test_is_monotonic() {
1875        assert!(FoldAggKind::Sum.is_monotonic());
1876        assert!(FoldAggKind::Max.is_monotonic());
1877        assert!(FoldAggKind::Min.is_monotonic());
1878        assert!(FoldAggKind::Count.is_monotonic());
1879        assert!(FoldAggKind::Nor.is_monotonic());
1880        assert!(FoldAggKind::Prod.is_monotonic());
1881        assert!(!FoldAggKind::Avg.is_monotonic());
1882        assert!(!FoldAggKind::Collect.is_monotonic());
1883    }
1884
1885    #[test]
1886    fn test_monotonicity_direction() {
1887        use super::MonotonicDirection;
1888        assert_eq!(
1889            FoldAggKind::Sum.monotonicity_direction(),
1890            Some(MonotonicDirection::NonDecreasing)
1891        );
1892        assert_eq!(
1893            FoldAggKind::Max.monotonicity_direction(),
1894            Some(MonotonicDirection::NonDecreasing)
1895        );
1896        assert_eq!(
1897            FoldAggKind::Count.monotonicity_direction(),
1898            Some(MonotonicDirection::NonDecreasing)
1899        );
1900        assert_eq!(
1901            FoldAggKind::Nor.monotonicity_direction(),
1902            Some(MonotonicDirection::NonDecreasing)
1903        );
1904        assert_eq!(
1905            FoldAggKind::Min.monotonicity_direction(),
1906            Some(MonotonicDirection::NonIncreasing)
1907        );
1908        assert_eq!(
1909            FoldAggKind::Prod.monotonicity_direction(),
1910            Some(MonotonicDirection::NonIncreasing)
1911        );
1912        assert_eq!(FoldAggKind::Avg.monotonicity_direction(), None);
1913        assert_eq!(FoldAggKind::Collect.monotonicity_direction(), None);
1914    }
1915
1916    #[test]
1917    fn test_identity_values() {
1918        assert_eq!(FoldAggKind::Sum.identity(), Some(0.0));
1919        assert_eq!(FoldAggKind::Count.identity(), Some(0.0));
1920        assert_eq!(FoldAggKind::Nor.identity(), Some(0.0));
1921        assert_eq!(FoldAggKind::Max.identity(), Some(f64::NEG_INFINITY));
1922        assert_eq!(FoldAggKind::Min.identity(), Some(f64::INFINITY));
1923        assert_eq!(FoldAggKind::Prod.identity(), Some(1.0));
1924        assert_eq!(FoldAggKind::Avg.identity(), None);
1925        assert_eq!(FoldAggKind::Collect.identity(), None);
1926    }
1927
1928    // ── Strict mode tests (Phase 5) ──────────────────────────────────────
1929
1930    async fn execute_fold_strict(
1931        input: Arc<dyn ExecutionPlan>,
1932        key_indices: Vec<usize>,
1933        fold_bindings: Vec<FoldBinding>,
1934        strict: bool,
1935    ) -> DFResult<RecordBatch> {
1936        let exec = FoldExec::new(input, key_indices, fold_bindings, strict, 1e-15);
1937        let ctx = SessionContext::new();
1938        let task_ctx = ctx.task_ctx();
1939        let stream = exec.execute(0, task_ctx).unwrap();
1940        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream).await?;
1941        if batches.is_empty() {
1942            Ok(RecordBatch::new_empty(exec.schema()))
1943        } else {
1944            arrow::compute::concat_batches(&exec.schema(), &batches).map_err(arrow_err)
1945        }
1946    }
1947
1948    #[tokio::test]
1949    async fn test_nor_strict_rejects_above_one() {
1950        let batch = make_test_batch(vec!["a"], vec![1.5]);
1951        let input = make_memory_exec(batch);
1952        let result = execute_fold_strict(
1953            input,
1954            vec![0],
1955            vec![FoldBinding {
1956                output_name: "p".into(),
1957                kind: FoldAggKind::Nor,
1958                input_col_index: 1,
1959            }],
1960            true,
1961        )
1962        .await;
1963        assert!(result.is_err());
1964        let err = result.unwrap_err().to_string();
1965        assert!(
1966            err.contains("strict_probability_domain"),
1967            "Expected strict error, got: {}",
1968            err
1969        );
1970    }
1971
1972    #[tokio::test]
1973    async fn test_nor_strict_rejects_negative() {
1974        let batch = make_test_batch(vec!["a"], vec![-0.1]);
1975        let input = make_memory_exec(batch);
1976        let result = execute_fold_strict(
1977            input,
1978            vec![0],
1979            vec![FoldBinding {
1980                output_name: "p".into(),
1981                kind: FoldAggKind::Nor,
1982                input_col_index: 1,
1983            }],
1984            true,
1985        )
1986        .await;
1987        assert!(result.is_err());
1988        let err = result.unwrap_err().to_string();
1989        assert!(
1990            err.contains("strict_probability_domain"),
1991            "Expected strict error, got: {}",
1992            err
1993        );
1994    }
1995
1996    #[tokio::test]
1997    async fn test_prod_strict_rejects_above_one() {
1998        let batch = make_test_batch(vec!["a"], vec![2.0]);
1999        let input = make_memory_exec(batch);
2000        let result = execute_fold_strict(
2001            input,
2002            vec![0],
2003            vec![FoldBinding {
2004                output_name: "p".into(),
2005                kind: FoldAggKind::Prod,
2006                input_col_index: 1,
2007            }],
2008            true,
2009        )
2010        .await;
2011        assert!(result.is_err());
2012        let err = result.unwrap_err().to_string();
2013        assert!(
2014            err.contains("strict_probability_domain"),
2015            "Expected strict error, got: {}",
2016            err
2017        );
2018    }
2019
2020    #[tokio::test]
2021    async fn test_prod_strict_rejects_negative() {
2022        let batch = make_test_batch(vec!["a"], vec![-0.5]);
2023        let input = make_memory_exec(batch);
2024        let result = execute_fold_strict(
2025            input,
2026            vec![0],
2027            vec![FoldBinding {
2028                output_name: "p".into(),
2029                kind: FoldAggKind::Prod,
2030                input_col_index: 1,
2031            }],
2032            true,
2033        )
2034        .await;
2035        assert!(result.is_err());
2036        let err = result.unwrap_err().to_string();
2037        assert!(
2038            err.contains("strict_probability_domain"),
2039            "Expected strict error, got: {}",
2040            err
2041        );
2042    }
2043
2044    #[tokio::test]
2045    async fn test_nor_strict_accepts_valid() {
2046        let batch = make_test_batch(vec!["a", "a"], vec![0.3, 0.5]);
2047        let input = make_memory_exec(batch);
2048        let result = execute_fold_strict(
2049            input,
2050            vec![0],
2051            vec![FoldBinding {
2052                output_name: "p".into(),
2053                kind: FoldAggKind::Nor,
2054                input_col_index: 1,
2055            }],
2056            true,
2057        )
2058        .await;
2059        assert!(result.is_ok(), "Expected Ok, got: {:?}", result.err());
2060        let batch = result.unwrap();
2061        let vals = batch
2062            .column(1)
2063            .as_any()
2064            .downcast_ref::<Float64Array>()
2065            .unwrap();
2066        let expected = 0.65; // 1 - (1-0.3)(1-0.5)
2067        assert!(
2068            (vals.value(0) - expected).abs() < 1e-10,
2069            "expected {}, got {}",
2070            expected,
2071            vals.value(0)
2072        );
2073    }
2074
2075    #[tokio::test]
2076    async fn test_count_all_groups_by_key() {
2077        // Two groups: "a" (2 rows), "b" (1 row)
2078        let batch = make_test_batch(vec!["a", "a", "b"], vec![10.0, 20.0, 30.0]);
2079        let input = make_memory_exec(batch);
2080        let result = execute_fold(
2081            input,
2082            vec![0],
2083            vec![FoldBinding {
2084                output_name: "cnt".to_string(),
2085                kind: FoldAggKind::CountAll,
2086                input_col_index: 0, // unused for CountAll
2087            }],
2088        )
2089        .await;
2090
2091        assert_eq!(result.num_rows(), 2, "Should have 2 groups");
2092        let counts = result
2093            .column(1)
2094            .as_any()
2095            .downcast_ref::<Int64Array>()
2096            .unwrap();
2097        assert_eq!(counts.value(0), 2, "Group 'a' should have count 2");
2098        assert_eq!(counts.value(1), 1, "Group 'b' should have count 1");
2099    }
2100}