Skip to main content

uni_query/query/df_graph/
locy_fold.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! FOLD operator for Locy.
5//!
6//! `FoldExec` applies fold (lattice-join) semantics: for each group of rows sharing
7//! the same KEY columns, it reduces non-key columns via their declared fold functions.
8
9use crate::query::df_graph::common::{
10    ScalarKey, arrow_err, compute_plan_properties, extract_scalar_key,
11};
12use arrow_array::builder::{Float64Builder, Int64Builder, LargeBinaryBuilder};
13use arrow_array::{Array, Float64Array, Int64Array, RecordBatch};
14use arrow_schema::{DataType, Field, Schema, SchemaRef};
15use datafusion::common::Result as DFResult;
16use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
17use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
18use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
19use futures::{Stream, TryStreamExt};
20use std::any::Any;
21use std::collections::HashMap;
22use std::fmt;
23use std::pin::Pin;
24use std::sync::Arc;
25use std::task::{Context, Poll};
26
/// Which way a fold aggregate's value is allowed to move between fixpoint iterations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MonotonicDirection {
    /// Each new value is greater than or equal to the previous one.
    NonDecreasing,
    /// Each new value is less than or equal to the previous one.
    NonIncreasing,
}

/// The aggregate function applied by a single FOLD binding.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FoldAggKind {
    /// Sum of the inputs.
    Sum,
    /// Largest input.
    Max,
    /// Smallest input.
    Min,
    /// Number of non-null inputs.
    Count,
    /// Number of rows in the group, whether or not they are null (like SQL `COUNT(*)`).
    CountAll,
    /// Arithmetic mean of the inputs.
    Avg,
    /// All inputs gathered into a list.
    Collect,
    /// Noisy-OR: 1 − ∏(1 − pᵢ)
    Nor,
    /// Product: ∏ pᵢ
    Prod,
}

impl FoldAggKind {
    /// Reports whether this aggregate is monotonic and therefore safe to use
    /// inside fixpoint iteration.
    ///
    /// An aggregate is monotonic exactly when it has a
    /// [`monotonicity_direction`](Self::monotonicity_direction).
    pub fn is_monotonic(&self) -> bool {
        self.monotonicity_direction().is_some()
    }

    /// Direction in which the aggregate moves across iterations.
    ///
    /// Returns `None` for the non-monotonic aggregates (`Avg`, `Collect`).
    pub fn monotonicity_direction(&self) -> Option<MonotonicDirection> {
        use MonotonicDirection::{NonDecreasing, NonIncreasing};

        match self {
            Self::Min | Self::Prod => Some(NonIncreasing),
            Self::Sum | Self::Max | Self::Count | Self::CountAll | Self::Nor => {
                Some(NonDecreasing)
            }
            Self::Avg | Self::Collect => None,
        }
    }

    /// Identity (neutral) element of the aggregate.
    ///
    /// Returns `None` for the non-monotonic aggregates (`Avg`, `Collect`).
    pub fn identity(&self) -> Option<f64> {
        let neutral = match self {
            Self::Avg | Self::Collect => return None,
            Self::Max => f64::NEG_INFINITY,
            Self::Min => f64::INFINITY,
            Self::Prod => 1.0,
            Self::Sum | Self::Count | Self::CountAll | Self::Nor => 0.0,
        };
        Some(neutral)
    }
}
88
89/// A single FOLD binding: aggregate an input column into an output column.
90#[derive(Debug, Clone)]
91pub struct FoldBinding {
92    pub output_name: String,
93    pub kind: FoldAggKind,
94    pub input_col_index: usize,
95    /// Column name for name-based resolution (more robust than positional index).
96    /// `None` for CountAll which has no input column.
97    pub input_col_name: Option<String>,
98}
99
100/// DataFusion `ExecutionPlan` that applies FOLD semantics.
101///
102/// Groups rows by KEY columns and computes aggregates (SUM, MAX, MIN, COUNT, AVG, COLLECT)
103/// for each fold binding. Output schema is KEY columns + fold output columns.
104#[derive(Debug)]
105pub struct FoldExec {
106    input: Arc<dyn ExecutionPlan>,
107    key_indices: Vec<usize>,
108    fold_bindings: Vec<FoldBinding>,
109    strict_probability_domain: bool,
110    probability_epsilon: f64,
111    schema: SchemaRef,
112    properties: PlanProperties,
113    metrics: ExecutionPlanMetricsSet,
114}
115
116impl FoldExec {
117    /// Create a new `FoldExec`.
118    ///
119    /// # Arguments
120    /// * `input` - Child execution plan
121    /// * `key_indices` - Indices of KEY columns for grouping
122    /// * `fold_bindings` - Aggregate bindings (output name, kind, input col index)
123    pub fn new(
124        input: Arc<dyn ExecutionPlan>,
125        key_indices: Vec<usize>,
126        fold_bindings: Vec<FoldBinding>,
127        strict_probability_domain: bool,
128        probability_epsilon: f64,
129    ) -> Self {
130        let input_schema = input.schema();
131        let schema = Self::build_output_schema(&input_schema, &key_indices, &fold_bindings);
132        let properties = compute_plan_properties(Arc::clone(&schema));
133
134        Self {
135            input,
136            key_indices,
137            fold_bindings,
138            strict_probability_domain,
139            probability_epsilon,
140            schema,
141            properties,
142            metrics: ExecutionPlanMetricsSet::new(),
143        }
144    }
145
146    fn build_output_schema(
147        input_schema: &SchemaRef,
148        key_indices: &[usize],
149        fold_bindings: &[FoldBinding],
150    ) -> SchemaRef {
151        let mut fields = Vec::new();
152
153        // Key columns preserve original types
154        for &ki in key_indices {
155            fields.push(Arc::new(input_schema.field(ki).clone()));
156        }
157
158        // Fold output columns
159        for binding in fold_bindings {
160            let output_type = match binding.kind {
161                FoldAggKind::Sum | FoldAggKind::Avg | FoldAggKind::Nor | FoldAggKind::Prod => {
162                    DataType::Float64
163                }
164                FoldAggKind::Count | FoldAggKind::CountAll => DataType::Int64,
165                FoldAggKind::Max | FoldAggKind::Min => {
166                    let idx = binding
167                        .input_col_name
168                        .as_ref()
169                        .and_then(|name| input_schema.index_of(name).ok())
170                        .unwrap_or(binding.input_col_index);
171                    if idx < input_schema.fields().len() {
172                        input_schema.field(idx).data_type().clone()
173                    } else {
174                        DataType::Float64
175                    }
176                }
177                FoldAggKind::Collect => DataType::LargeBinary,
178            };
179            fields.push(Arc::new(Field::new(
180                &binding.output_name,
181                output_type,
182                true,
183            )));
184        }
185
186        Arc::new(Schema::new(fields))
187    }
188}
189
190impl DisplayAs for FoldExec {
191    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
192        write!(
193            f,
194            "FoldExec: key_indices={:?}, bindings={:?}",
195            self.key_indices, self.fold_bindings
196        )
197    }
198}
199
200impl ExecutionPlan for FoldExec {
201    fn name(&self) -> &str {
202        "FoldExec"
203    }
204
205    fn as_any(&self) -> &dyn Any {
206        self
207    }
208
209    fn schema(&self) -> SchemaRef {
210        Arc::clone(&self.schema)
211    }
212
213    fn properties(&self) -> &PlanProperties {
214        &self.properties
215    }
216
217    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
218        vec![&self.input]
219    }
220
221    fn with_new_children(
222        self: Arc<Self>,
223        children: Vec<Arc<dyn ExecutionPlan>>,
224    ) -> DFResult<Arc<dyn ExecutionPlan>> {
225        if children.len() != 1 {
226            return Err(datafusion::error::DataFusionError::Plan(
227                "FoldExec requires exactly one child".to_string(),
228            ));
229        }
230        Ok(Arc::new(Self::new(
231            Arc::clone(&children[0]),
232            self.key_indices.clone(),
233            self.fold_bindings.clone(),
234            self.strict_probability_domain,
235            self.probability_epsilon,
236        )))
237    }
238
239    fn execute(
240        &self,
241        partition: usize,
242        context: Arc<TaskContext>,
243    ) -> DFResult<SendableRecordBatchStream> {
244        let input_stream = self.input.execute(partition, Arc::clone(&context))?;
245        let metrics = BaselineMetrics::new(&self.metrics, partition);
246        let key_indices = self.key_indices.clone();
247        let fold_bindings = self.fold_bindings.clone();
248        let strict = self.strict_probability_domain;
249        let epsilon = self.probability_epsilon;
250        let output_schema = Arc::clone(&self.schema);
251        let input_schema = self.input.schema();
252
253        let fut = async move {
254            let batches: Vec<RecordBatch> = input_stream.try_collect().await?;
255
256            if batches.is_empty() {
257                return Ok(RecordBatch::new_empty(output_schema));
258            }
259
260            // Use the actual batch schema (may differ from pre-computed input_schema
261            // after schema reconciliation in schemaless mode).
262            let actual_schema = batches
263                .first()
264                .map(|b| b.schema())
265                .unwrap_or(input_schema.clone());
266            let batch =
267                arrow::compute::concat_batches(&actual_schema, &batches).map_err(arrow_err)?;
268
269            if batch.num_rows() == 0 {
270                return Ok(RecordBatch::new_empty(output_schema));
271            }
272
273            // Group by key columns → row indices, preserving insertion order
274            let mut groups: HashMap<Vec<ScalarKey>, Vec<usize>> = HashMap::new();
275            let mut ordered_keys: Vec<Vec<ScalarKey>> = Vec::new();
276            for row_idx in 0..batch.num_rows() {
277                let key = extract_scalar_key(&batch, &key_indices, row_idx);
278                let entry = groups.entry(key.clone());
279                if matches!(entry, std::collections::hash_map::Entry::Vacant(_)) {
280                    ordered_keys.push(key);
281                }
282                entry.or_default().push(row_idx);
283            }
284
285            let num_groups = ordered_keys.len();
286
287            // Build output columns
288            let mut output_columns: Vec<arrow_array::ArrayRef> = Vec::new();
289
290            // Key columns: take from first row of each group
291            for &ki in &key_indices {
292                if ki >= batch.num_columns() {
293                    continue; // Skip invalid indices after schema reconciliation
294                }
295                let col = batch.column(ki);
296                let first_indices: Vec<u32> =
297                    ordered_keys.iter().map(|k| groups[k][0] as u32).collect();
298                let idx_array = arrow_array::UInt32Array::from(first_indices);
299                let taken = arrow::compute::take(col.as_ref(), &idx_array, None).map_err(|e| {
300                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
301                })?;
302                output_columns.push(taken);
303            }
304
305            // Fold binding columns: compute aggregates per group
306            for binding in &fold_bindings {
307                let col: Arc<dyn Array> = if binding.kind == FoldAggKind::CountAll {
308                    // CountAll doesn't need an input column — use a dummy
309                    Arc::new(arrow_array::Int64Array::from(vec![0i64; batch.num_rows()]))
310                } else {
311                    // Resolve input column: prefer name-based lookup, fall back to index.
312                    let resolved_idx = binding
313                        .input_col_name
314                        .as_ref()
315                        .and_then(|name| batch.schema().index_of(name).ok())
316                        .unwrap_or(binding.input_col_index);
317                    if resolved_idx < batch.num_columns() {
318                        Arc::clone(batch.column(resolved_idx))
319                    } else {
320                        // Column not found — use zeros as fallback
321                        Arc::new(arrow_array::Float64Array::from(vec![
322                            0.0f64;
323                            batch.num_rows()
324                        ]))
325                    }
326                };
327                let agg_col = compute_fold_aggregate(
328                    col.as_ref(),
329                    &binding.kind,
330                    &ordered_keys,
331                    &groups,
332                    num_groups,
333                    strict,
334                    epsilon,
335                )?;
336                output_columns.push(agg_col);
337            }
338
339            RecordBatch::try_new(output_schema, output_columns).map_err(arrow_err)
340        };
341
342        Ok(Box::pin(FoldStream {
343            state: FoldStreamState::Running(Box::pin(fut)),
344            schema: Arc::clone(&self.schema),
345            metrics,
346        }))
347    }
348
349    fn metrics(&self) -> Option<MetricsSet> {
350        Some(self.metrics.clone_inner())
351    }
352}
353
354// ---------------------------------------------------------------------------
355// Aggregate computation
356// ---------------------------------------------------------------------------
357
358fn compute_fold_aggregate(
359    col: &dyn Array,
360    kind: &FoldAggKind,
361    ordered_keys: &[Vec<ScalarKey>],
362    groups: &HashMap<Vec<ScalarKey>, Vec<usize>>,
363    num_groups: usize,
364    strict: bool,
365    probability_epsilon: f64,
366) -> DFResult<arrow_array::ArrayRef> {
367    match kind {
368        FoldAggKind::Sum => {
369            let mut builder = Float64Builder::with_capacity(num_groups);
370            for key in ordered_keys {
371                builder.append_option(sum_f64(col, &groups[key]));
372            }
373            Ok(Arc::new(builder.finish()))
374        }
375        FoldAggKind::Count => {
376            let mut builder = Int64Builder::with_capacity(num_groups);
377            for key in ordered_keys {
378                let indices = &groups[key];
379                let count = indices.iter().filter(|&&i| !col.is_null(i)).count();
380                builder.append_value(count as i64);
381            }
382            Ok(Arc::new(builder.finish()))
383        }
384        FoldAggKind::CountAll => {
385            let mut builder = Int64Builder::with_capacity(num_groups);
386            for key in ordered_keys {
387                let indices = &groups[key];
388                builder.append_value(indices.len() as i64);
389            }
390            Ok(Arc::new(builder.finish()))
391        }
392        FoldAggKind::Max => compute_minmax(col, ordered_keys, groups, num_groups, false),
393        FoldAggKind::Min => compute_minmax(col, ordered_keys, groups, num_groups, true),
394        FoldAggKind::Avg => {
395            let mut builder = Float64Builder::with_capacity(num_groups);
396            for key in ordered_keys {
397                let indices = &groups[key];
398                let count = indices.iter().filter(|&&i| !col.is_null(i)).count();
399                let avg = sum_f64(col, indices)
400                    .filter(|_| count > 0)
401                    .map(|s| s / count as f64);
402                builder.append_option(avg);
403            }
404            Ok(Arc::new(builder.finish()))
405        }
406        FoldAggKind::Collect => {
407            let mut builder = LargeBinaryBuilder::with_capacity(num_groups, num_groups * 32);
408            for key in ordered_keys {
409                let values: Vec<uni_common::Value> = groups[key]
410                    .iter()
411                    .filter(|&&i| !col.is_null(i))
412                    .map(|&i| scalar_to_value(col, i))
413                    .collect();
414                let encoded =
415                    uni_common::cypher_value_codec::encode(&uni_common::Value::List(values));
416                builder.append_value(&encoded);
417            }
418            Ok(Arc::new(builder.finish()))
419        }
420        FoldAggKind::Nor => {
421            let mut builder = Float64Builder::with_capacity(num_groups);
422            for key in ordered_keys {
423                let indices = &groups[key];
424                builder.append_option(noisy_or_f64(col, indices, strict)?);
425            }
426            Ok(Arc::new(builder.finish()))
427        }
428        FoldAggKind::Prod => {
429            let mut builder = Float64Builder::with_capacity(num_groups);
430            for key in ordered_keys {
431                builder.append_option(product_f64(col, &groups[key], strict, probability_epsilon)?);
432            }
433            Ok(Arc::new(builder.finish()))
434        }
435    }
436}
437
438fn sum_f64(col: &dyn Array, indices: &[usize]) -> Option<f64> {
439    let mut sum = 0.0;
440    let mut has_value = false;
441    for &i in indices {
442        if col.is_null(i) {
443            continue;
444        }
445        has_value = true;
446        if let Some(arr) = col.as_any().downcast_ref::<Float64Array>() {
447            sum += arr.value(i);
448        } else if let Some(arr) = col.as_any().downcast_ref::<Int64Array>() {
449            sum += arr.value(i) as f64;
450        }
451    }
452    if has_value { Some(sum) } else { None }
453}
454
455/// Noisy-OR: P = 1 − ∏(1 − pᵢ). Inputs clamped to [0, 1] unless strict.
456fn noisy_or_f64(col: &dyn Array, indices: &[usize], strict: bool) -> DFResult<Option<f64>> {
457    let mut complement_product = 1.0;
458    let mut has_value = false;
459    for &i in indices {
460        if col.is_null(i) {
461            continue;
462        }
463        has_value = true;
464        let raw = if let Some(arr) = col.as_any().downcast_ref::<Float64Array>() {
465            arr.value(i)
466        } else if let Some(arr) = col.as_any().downcast_ref::<Int64Array>() {
467            arr.value(i) as f64
468        } else {
469            continue;
470        };
471        if strict && !(0.0..=1.0).contains(&raw) {
472            return Err(datafusion::error::DataFusionError::Execution(format!(
473                "strict_probability_domain: MNOR input {raw} is outside [0, 1]"
474            )));
475        }
476        if !strict && !(0.0..=1.0).contains(&raw) {
477            tracing::warn!(
478                "MNOR input {raw} outside [0,1], clamped to {}",
479                raw.clamp(0.0, 1.0)
480            );
481        }
482        let p = raw.clamp(0.0, 1.0);
483        complement_product *= 1.0 - p;
484    }
485    if has_value {
486        Ok(Some(1.0 - complement_product))
487    } else {
488        Ok(None)
489    }
490}
491
492/// Product: P = ∏ pᵢ. Inputs clamped to [0, 1] unless strict.
493///
494/// Switches to log-space when the running product drops below
495/// `probability_epsilon` to prevent floating-point underflow.
496fn product_f64(
497    col: &dyn Array,
498    indices: &[usize],
499    strict: bool,
500    probability_epsilon: f64,
501) -> DFResult<Option<f64>> {
502    let mut product = 1.0;
503    let mut log_sum = 0.0;
504    let mut use_log = false;
505    let mut has_value = false;
506    for &i in indices {
507        if col.is_null(i) {
508            continue;
509        }
510        has_value = true;
511        let raw = if let Some(arr) = col.as_any().downcast_ref::<Float64Array>() {
512            arr.value(i)
513        } else if let Some(arr) = col.as_any().downcast_ref::<Int64Array>() {
514            arr.value(i) as f64
515        } else {
516            continue;
517        };
518        if strict && !(0.0..=1.0).contains(&raw) {
519            return Err(datafusion::error::DataFusionError::Execution(format!(
520                "strict_probability_domain: MPROD input {raw} is outside [0, 1]"
521            )));
522        }
523        if !strict && !(0.0..=1.0).contains(&raw) {
524            tracing::warn!(
525                "MPROD input {raw} outside [0,1], clamped to {}",
526                raw.clamp(0.0, 1.0)
527            );
528        }
529        let p = raw.clamp(0.0, 1.0);
530        if p == 0.0 {
531            return Ok(Some(0.0));
532        }
533        if use_log {
534            log_sum += p.ln();
535        } else {
536            product *= p;
537            if product < probability_epsilon {
538                // Switch to log-space to prevent underflow
539                log_sum = product.ln();
540                use_log = true;
541            }
542        }
543    }
544    if !has_value {
545        return Ok(None);
546    }
547    if use_log {
548        Ok(Some(log_sum.exp()))
549    } else {
550        Ok(Some(product))
551    }
552}
553
554fn compute_minmax(
555    col: &dyn Array,
556    ordered_keys: &[Vec<ScalarKey>],
557    groups: &HashMap<Vec<ScalarKey>, Vec<usize>>,
558    num_groups: usize,
559    is_min: bool,
560) -> DFResult<arrow_array::ArrayRef> {
561    match col.data_type() {
562        DataType::Int64 => {
563            let arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
564            let mut builder = Int64Builder::with_capacity(num_groups);
565            for key in ordered_keys {
566                let mut result: Option<i64> = None;
567                for &i in &groups[key] {
568                    if !arr.is_null(i) {
569                        let v = arr.value(i);
570                        result = Some(match result {
571                            None => v,
572                            Some(cur) if is_min => cur.min(v),
573                            Some(cur) => cur.max(v),
574                        });
575                    }
576                }
577                builder.append_option(result);
578            }
579            Ok(Arc::new(builder.finish()))
580        }
581        DataType::Float64 => {
582            let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();
583            let mut builder = Float64Builder::with_capacity(num_groups);
584            for key in ordered_keys {
585                let mut result: Option<f64> = None;
586                for &i in &groups[key] {
587                    if !arr.is_null(i) {
588                        let v = arr.value(i);
589                        result = Some(match result {
590                            None => v,
591                            Some(cur) if is_min => cur.min(v),
592                            Some(cur) => cur.max(v),
593                        });
594                    }
595                }
596                builder.append_option(result);
597            }
598            Ok(Arc::new(builder.finish()))
599        }
600        dt => {
601            // Fallback: treat as string comparison.
602            // Use LargeStringBuilder for LargeUtf8 input to match the output schema
603            // (build_output_schema preserves the input type for MAX/MIN).
604            let use_large = matches!(dt, DataType::LargeUtf8);
605            let mut values: Vec<Option<String>> = Vec::with_capacity(num_groups);
606            for key in ordered_keys {
607                let indices = &groups[key];
608                let mut result: Option<String> = None;
609                for &i in indices {
610                    if col.is_null(i) {
611                        continue;
612                    }
613                    let v = format!("{:?}", scalar_to_value(col, i));
614                    result = Some(match result {
615                        None => v,
616                        Some(cur) if is_min && v < cur => v,
617                        Some(cur) if !is_min && v > cur => v,
618                        Some(cur) => cur,
619                    });
620                }
621                values.push(result);
622            }
623            Ok(build_optional_string_array(&values, use_large))
624        }
625    }
626}
627
628fn build_optional_string_array(
629    values: &[Option<String>],
630    use_large: bool,
631) -> arrow_array::ArrayRef {
632    if use_large {
633        let mut builder = arrow_array::builder::LargeStringBuilder::new();
634        for v in values {
635            match v {
636                Some(s) => builder.append_value(s),
637                None => builder.append_null(),
638            }
639        }
640        Arc::new(builder.finish())
641    } else {
642        let mut builder = arrow_array::builder::StringBuilder::new();
643        for v in values {
644            match v {
645                Some(s) => builder.append_value(s),
646                None => builder.append_null(),
647            }
648        }
649        Arc::new(builder.finish())
650    }
651}
652
653fn scalar_to_value(col: &dyn Array, row_idx: usize) -> uni_common::Value {
654    if col.is_null(row_idx) {
655        return uni_common::Value::Null;
656    }
657    match col.data_type() {
658        DataType::Int64 => {
659            let arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
660            uni_common::Value::Int(arr.value(row_idx))
661        }
662        DataType::Float64 => {
663            let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();
664            uni_common::Value::Float(arr.value(row_idx))
665        }
666        DataType::Utf8 => {
667            let arr = col
668                .as_any()
669                .downcast_ref::<arrow_array::StringArray>()
670                .unwrap();
671            uni_common::Value::String(arr.value(row_idx).to_string())
672        }
673        DataType::LargeUtf8 => {
674            let arr = col
675                .as_any()
676                .downcast_ref::<arrow_array::LargeStringArray>()
677                .unwrap();
678            uni_common::Value::String(arr.value(row_idx).to_string())
679        }
680        DataType::Boolean => {
681            let arr = col
682                .as_any()
683                .downcast_ref::<arrow_array::BooleanArray>()
684                .unwrap();
685            uni_common::Value::Bool(arr.value(row_idx))
686        }
687        DataType::LargeBinary => {
688            let arr = col
689                .as_any()
690                .downcast_ref::<arrow_array::LargeBinaryArray>()
691                .unwrap();
692            let bytes = arr.value(row_idx);
693            uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null)
694        }
695        _ => uni_common::Value::Null,
696    }
697}
698
699// ---------------------------------------------------------------------------
700// Stream implementation
701// ---------------------------------------------------------------------------
702
703enum FoldStreamState {
704    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
705    Done,
706}
707
708struct FoldStream {
709    state: FoldStreamState,
710    schema: SchemaRef,
711    metrics: BaselineMetrics,
712}
713
714impl Stream for FoldStream {
715    type Item = DFResult<RecordBatch>;
716
717    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
718        match &mut self.state {
719            FoldStreamState::Running(fut) => match fut.as_mut().poll(cx) {
720                Poll::Ready(Ok(batch)) => {
721                    self.metrics.record_output(batch.num_rows());
722                    self.state = FoldStreamState::Done;
723                    Poll::Ready(Some(Ok(batch)))
724                }
725                Poll::Ready(Err(e)) => {
726                    self.state = FoldStreamState::Done;
727                    Poll::Ready(Some(Err(e)))
728                }
729                Poll::Pending => Poll::Pending,
730            },
731            FoldStreamState::Done => Poll::Ready(None),
732        }
733    }
734}
735
736impl RecordBatchStream for FoldStream {
737    fn schema(&self) -> SchemaRef {
738        Arc::clone(&self.schema)
739    }
740}
741
742#[cfg(test)]
743mod tests {
744    use super::*;
745    use arrow_array::{Float64Array, Int64Array, StringArray};
746    use arrow_schema::{DataType, Field, Schema};
747    use datafusion::physical_plan::memory::MemoryStream;
748    use datafusion::prelude::SessionContext;
749
750    fn make_test_batch(names: Vec<&str>, values: Vec<f64>) -> RecordBatch {
751        let schema = Arc::new(Schema::new(vec![
752            Field::new("name", DataType::Utf8, true),
753            Field::new("value", DataType::Float64, true),
754        ]));
755        RecordBatch::try_new(
756            schema,
757            vec![
758                Arc::new(StringArray::from(
759                    names.into_iter().map(Some).collect::<Vec<_>>(),
760                )),
761                Arc::new(Float64Array::from(values)),
762            ],
763        )
764        .unwrap()
765    }
766
767    fn make_memory_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
768        let schema = batch.schema();
769        Arc::new(TestMemoryExec {
770            batches: vec![batch],
771            schema: schema.clone(),
772            properties: compute_plan_properties(schema),
773        })
774    }
775
776    #[derive(Debug)]
777    struct TestMemoryExec {
778        batches: Vec<RecordBatch>,
779        schema: SchemaRef,
780        properties: PlanProperties,
781    }
782
783    impl DisplayAs for TestMemoryExec {
784        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
785            write!(f, "TestMemoryExec")
786        }
787    }
788
789    impl ExecutionPlan for TestMemoryExec {
790        fn name(&self) -> &str {
791            "TestMemoryExec"
792        }
793        fn as_any(&self) -> &dyn Any {
794            self
795        }
796        fn schema(&self) -> SchemaRef {
797            Arc::clone(&self.schema)
798        }
799        fn properties(&self) -> &PlanProperties {
800            &self.properties
801        }
802        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
803            vec![]
804        }
805        fn with_new_children(
806            self: Arc<Self>,
807            _children: Vec<Arc<dyn ExecutionPlan>>,
808        ) -> DFResult<Arc<dyn ExecutionPlan>> {
809            Ok(self)
810        }
811        fn execute(
812            &self,
813            _partition: usize,
814            _context: Arc<TaskContext>,
815        ) -> DFResult<SendableRecordBatchStream> {
816            Ok(Box::pin(MemoryStream::try_new(
817                self.batches.clone(),
818                Arc::clone(&self.schema),
819                None,
820            )?))
821        }
822    }
823
824    async fn execute_fold(
825        input: Arc<dyn ExecutionPlan>,
826        key_indices: Vec<usize>,
827        fold_bindings: Vec<FoldBinding>,
828    ) -> RecordBatch {
829        let exec = FoldExec::new(input, key_indices, fold_bindings, false, 1e-15);
830        let ctx = SessionContext::new();
831        let task_ctx = ctx.task_ctx();
832        let stream = exec.execute(0, task_ctx).unwrap();
833        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream)
834            .await
835            .unwrap();
836        if batches.is_empty() {
837            RecordBatch::new_empty(exec.schema())
838        } else {
839            arrow::compute::concat_batches(&exec.schema(), &batches).unwrap()
840        }
841    }
842
843    #[tokio::test]
844    async fn test_sum_single_group() {
845        let batch = make_test_batch(vec!["a", "a", "a"], vec![1.0, 2.0, 3.0]);
846        let input = make_memory_exec(batch);
847        let result = execute_fold(
848            input,
849            vec![0],
850            vec![FoldBinding {
851                output_name: "total".to_string(),
852                kind: FoldAggKind::Sum,
853                input_col_index: 1,
854                input_col_name: None,
855            }],
856        )
857        .await;
858
859        assert_eq!(result.num_rows(), 1);
860        let totals = result
861            .column(1)
862            .as_any()
863            .downcast_ref::<Float64Array>()
864            .unwrap();
865        assert!((totals.value(0) - 6.0).abs() < f64::EPSILON);
866    }
867
868    #[tokio::test]
869    async fn test_count_non_null() {
870        let schema = Arc::new(Schema::new(vec![
871            Field::new("name", DataType::Utf8, true),
872            Field::new("value", DataType::Float64, true),
873        ]));
874        let batch = RecordBatch::try_new(
875            schema,
876            vec![
877                Arc::new(StringArray::from(vec![Some("a"), Some("a"), Some("a")])),
878                Arc::new(Float64Array::from(vec![Some(1.0), None, Some(3.0)])),
879            ],
880        )
881        .unwrap();
882        let input = make_memory_exec(batch);
883        let result = execute_fold(
884            input,
885            vec![0],
886            vec![FoldBinding {
887                output_name: "cnt".to_string(),
888                kind: FoldAggKind::Count,
889                input_col_index: 1,
890                input_col_name: None,
891            }],
892        )
893        .await;
894
895        assert_eq!(result.num_rows(), 1);
896        let counts = result
897            .column(1)
898            .as_any()
899            .downcast_ref::<Int64Array>()
900            .unwrap();
901        assert_eq!(counts.value(0), 2); // null not counted
902    }
903
904    #[tokio::test]
905    async fn test_max_min() {
906        let batch = make_test_batch(vec!["a", "a", "a"], vec![3.0, 1.0, 5.0]);
907        let input_max = make_memory_exec(batch.clone());
908        let input_min = make_memory_exec(batch);
909
910        let result_max = execute_fold(
911            input_max,
912            vec![0],
913            vec![FoldBinding {
914                output_name: "mx".to_string(),
915                kind: FoldAggKind::Max,
916                input_col_index: 1,
917                input_col_name: None,
918            }],
919        )
920        .await;
921        let result_min = execute_fold(
922            input_min,
923            vec![0],
924            vec![FoldBinding {
925                output_name: "mn".to_string(),
926                kind: FoldAggKind::Min,
927                input_col_index: 1,
928                input_col_name: None,
929            }],
930        )
931        .await;
932
933        let max_vals = result_max
934            .column(1)
935            .as_any()
936            .downcast_ref::<Float64Array>()
937            .unwrap();
938        assert_eq!(max_vals.value(0), 5.0);
939
940        let min_vals = result_min
941            .column(1)
942            .as_any()
943            .downcast_ref::<Float64Array>()
944            .unwrap();
945        assert_eq!(min_vals.value(0), 1.0);
946    }
947
948    #[tokio::test]
949    async fn test_avg() {
950        let batch = make_test_batch(vec!["a", "a", "a", "a"], vec![2.0, 4.0, 6.0, 8.0]);
951        let input = make_memory_exec(batch);
952        let result = execute_fold(
953            input,
954            vec![0],
955            vec![FoldBinding {
956                output_name: "average".to_string(),
957                kind: FoldAggKind::Avg,
958                input_col_index: 1,
959                input_col_name: None,
960            }],
961        )
962        .await;
963
964        assert_eq!(result.num_rows(), 1);
965        let avgs = result
966            .column(1)
967            .as_any()
968            .downcast_ref::<Float64Array>()
969            .unwrap();
970        assert!((avgs.value(0) - 5.0).abs() < f64::EPSILON);
971    }
972
973    #[tokio::test]
974    async fn test_multiple_groups() {
975        let batch = make_test_batch(
976            vec!["a", "a", "b", "b", "b"],
977            vec![1.0, 2.0, 10.0, 20.0, 30.0],
978        );
979        let input = make_memory_exec(batch);
980        let result = execute_fold(
981            input,
982            vec![0],
983            vec![FoldBinding {
984                output_name: "total".to_string(),
985                kind: FoldAggKind::Sum,
986                input_col_index: 1,
987                input_col_name: None,
988            }],
989        )
990        .await;
991
992        assert_eq!(result.num_rows(), 2);
993        let names = result
994            .column(0)
995            .as_any()
996            .downcast_ref::<StringArray>()
997            .unwrap();
998        let totals = result
999            .column(1)
1000            .as_any()
1001            .downcast_ref::<Float64Array>()
1002            .unwrap();
1003
1004        for i in 0..2 {
1005            match names.value(i) {
1006                "a" => assert!((totals.value(i) - 3.0).abs() < f64::EPSILON),
1007                "b" => assert!((totals.value(i) - 60.0).abs() < f64::EPSILON),
1008                _ => panic!("unexpected name"),
1009            }
1010        }
1011    }
1012
1013    #[tokio::test]
1014    async fn test_empty_input() {
1015        let schema = Arc::new(Schema::new(vec![
1016            Field::new("name", DataType::Utf8, true),
1017            Field::new("value", DataType::Float64, true),
1018        ]));
1019        let batch = RecordBatch::new_empty(schema);
1020        let input = make_memory_exec(batch);
1021        let result = execute_fold(
1022            input,
1023            vec![0],
1024            vec![FoldBinding {
1025                output_name: "total".to_string(),
1026                kind: FoldAggKind::Sum,
1027                input_col_index: 1,
1028                input_col_name: None,
1029            }],
1030        )
1031        .await;
1032
1033        assert_eq!(result.num_rows(), 0);
1034    }
1035
1036    #[tokio::test]
1037    async fn test_multiple_bindings() {
1038        let batch = make_test_batch(vec!["a", "a", "a"], vec![1.0, 2.0, 3.0]);
1039        let input = make_memory_exec(batch);
1040        let result = execute_fold(
1041            input,
1042            vec![0],
1043            vec![
1044                FoldBinding {
1045                    output_name: "total".to_string(),
1046                    kind: FoldAggKind::Sum,
1047                    input_col_index: 1,
1048                    input_col_name: None,
1049                },
1050                FoldBinding {
1051                    output_name: "cnt".to_string(),
1052                    kind: FoldAggKind::Count,
1053                    input_col_index: 1,
1054                    input_col_name: None,
1055                },
1056                FoldBinding {
1057                    output_name: "mx".to_string(),
1058                    kind: FoldAggKind::Max,
1059                    input_col_index: 1,
1060                    input_col_name: None,
1061                },
1062            ],
1063        )
1064        .await;
1065
1066        assert_eq!(result.num_rows(), 1);
1067        assert_eq!(result.num_columns(), 4); // name + total + cnt + mx
1068
1069        let totals = result
1070            .column(1)
1071            .as_any()
1072            .downcast_ref::<Float64Array>()
1073            .unwrap();
1074        assert!((totals.value(0) - 6.0).abs() < f64::EPSILON);
1075
1076        let counts = result
1077            .column(2)
1078            .as_any()
1079            .downcast_ref::<Int64Array>()
1080            .unwrap();
1081        assert_eq!(counts.value(0), 3);
1082
1083        let maxes = result
1084            .column(3)
1085            .as_any()
1086            .downcast_ref::<Float64Array>()
1087            .unwrap();
1088        assert_eq!(maxes.value(0), 3.0);
1089    }
1090
    // ── MNOR tests ───────────────────────────────────────────────────────
1092
1093    #[tokio::test]
1094    async fn test_nor_single_group() {
1095        // MNOR({0.3, 0.5}) = 1 - (1-0.3)*(1-0.5) = 1 - 0.7*0.5 = 1 - 0.35 = 0.65
1096        let batch = make_test_batch(vec!["a", "a"], vec![0.3, 0.5]);
1097        let input = make_memory_exec(batch);
1098        let result = execute_fold(
1099            input,
1100            vec![0],
1101            vec![FoldBinding {
1102                output_name: "prob".to_string(),
1103                kind: FoldAggKind::Nor,
1104                input_col_index: 1,
1105                input_col_name: None,
1106            }],
1107        )
1108        .await;
1109
1110        assert_eq!(result.num_rows(), 1);
1111        let vals = result
1112            .column(1)
1113            .as_any()
1114            .downcast_ref::<Float64Array>()
1115            .unwrap();
1116        assert!((vals.value(0) - 0.65).abs() < 1e-10);
1117    }
1118
1119    #[tokio::test]
1120    async fn test_nor_identity() {
1121        // MNOR({0.0, 0.0}) = 1 - (1-0)*(1-0) = 1 - 1 = 0.0
1122        let batch = make_test_batch(vec!["a", "a"], vec![0.0, 0.0]);
1123        let input = make_memory_exec(batch);
1124        let result = execute_fold(
1125            input,
1126            vec![0],
1127            vec![FoldBinding {
1128                output_name: "prob".to_string(),
1129                kind: FoldAggKind::Nor,
1130                input_col_index: 1,
1131                input_col_name: None,
1132            }],
1133        )
1134        .await;
1135
1136        let vals = result
1137            .column(1)
1138            .as_any()
1139            .downcast_ref::<Float64Array>()
1140            .unwrap();
1141        assert!((vals.value(0) - 0.0).abs() < 1e-10);
1142    }
1143
1144    #[tokio::test]
1145    async fn test_nor_clamping() {
1146        // Out-of-range values should be clamped to [0, 1]
1147        let batch = make_test_batch(vec!["a", "a"], vec![-0.5, 1.5]);
1148        let input = make_memory_exec(batch);
1149        let result = execute_fold(
1150            input,
1151            vec![0],
1152            vec![FoldBinding {
1153                output_name: "prob".to_string(),
1154                kind: FoldAggKind::Nor,
1155                input_col_index: 1,
1156                input_col_name: None,
1157            }],
1158        )
1159        .await;
1160
1161        let vals = result
1162            .column(1)
1163            .as_any()
1164            .downcast_ref::<Float64Array>()
1165            .unwrap();
1166        // Clamped to (0.0, 1.0): MNOR = 1 - (1-0)*(1-1) = 1 - 1*0 = 1.0
1167        assert!((vals.value(0) - 1.0).abs() < 1e-10);
1168    }
1169
1170    #[tokio::test]
1171    async fn test_nor_multiple_groups() {
1172        let batch = make_test_batch(vec!["a", "a", "b", "b"], vec![0.3, 0.5, 0.1, 0.2]);
1173        let input = make_memory_exec(batch);
1174        let result = execute_fold(
1175            input,
1176            vec![0],
1177            vec![FoldBinding {
1178                output_name: "prob".to_string(),
1179                kind: FoldAggKind::Nor,
1180                input_col_index: 1,
1181                input_col_name: None,
1182            }],
1183        )
1184        .await;
1185
1186        assert_eq!(result.num_rows(), 2);
1187        let names = result
1188            .column(0)
1189            .as_any()
1190            .downcast_ref::<StringArray>()
1191            .unwrap();
1192        let vals = result
1193            .column(1)
1194            .as_any()
1195            .downcast_ref::<Float64Array>()
1196            .unwrap();
1197
1198        for i in 0..2 {
1199            match names.value(i) {
1200                // MNOR({0.3, 0.5}) = 0.65
1201                "a" => assert!((vals.value(i) - 0.65).abs() < 1e-10),
1202                // MNOR({0.1, 0.2}) = 1 - 0.9*0.8 = 1 - 0.72 = 0.28
1203                "b" => assert!((vals.value(i) - 0.28).abs() < 1e-10),
1204                _ => panic!("unexpected name"),
1205            }
1206        }
1207    }
1208
    // ── MPROD tests ──────────────────────────────────────────────────────
1210
1211    #[tokio::test]
1212    async fn test_prod_single_group() {
1213        // MPROD({0.6, 0.8}) = 0.48
1214        let batch = make_test_batch(vec!["a", "a"], vec![0.6, 0.8]);
1215        let input = make_memory_exec(batch);
1216        let result = execute_fold(
1217            input,
1218            vec![0],
1219            vec![FoldBinding {
1220                output_name: "prob".to_string(),
1221                kind: FoldAggKind::Prod,
1222                input_col_index: 1,
1223                input_col_name: None,
1224            }],
1225        )
1226        .await;
1227
1228        assert_eq!(result.num_rows(), 1);
1229        let vals = result
1230            .column(1)
1231            .as_any()
1232            .downcast_ref::<Float64Array>()
1233            .unwrap();
1234        assert!((vals.value(0) - 0.48).abs() < 1e-10);
1235    }
1236
1237    #[tokio::test]
1238    async fn test_prod_identity() {
1239        // MPROD({1.0, 1.0}) = 1.0
1240        let batch = make_test_batch(vec!["a", "a"], vec![1.0, 1.0]);
1241        let input = make_memory_exec(batch);
1242        let result = execute_fold(
1243            input,
1244            vec![0],
1245            vec![FoldBinding {
1246                output_name: "prob".to_string(),
1247                kind: FoldAggKind::Prod,
1248                input_col_index: 1,
1249                input_col_name: None,
1250            }],
1251        )
1252        .await;
1253
1254        let vals = result
1255            .column(1)
1256            .as_any()
1257            .downcast_ref::<Float64Array>()
1258            .unwrap();
1259        assert!((vals.value(0) - 1.0).abs() < 1e-10);
1260    }
1261
1262    #[tokio::test]
1263    async fn test_prod_zero_absorbing() {
1264        // MPROD with 0.0 = 0.0 (zero is absorbing element)
1265        let batch = make_test_batch(vec!["a", "a", "a"], vec![0.5, 0.0, 0.8]);
1266        let input = make_memory_exec(batch);
1267        let result = execute_fold(
1268            input,
1269            vec![0],
1270            vec![FoldBinding {
1271                output_name: "prob".to_string(),
1272                kind: FoldAggKind::Prod,
1273                input_col_index: 1,
1274                input_col_name: None,
1275            }],
1276        )
1277        .await;
1278
1279        let vals = result
1280            .column(1)
1281            .as_any()
1282            .downcast_ref::<Float64Array>()
1283            .unwrap();
1284        assert!((vals.value(0) - 0.0).abs() < 1e-10);
1285    }
1286
1287    #[tokio::test]
1288    async fn test_prod_underflow_protection() {
1289        // 50 × 0.5 ≈ 8.88e-16, should not be exactly 0 thanks to log-space
1290        let names: Vec<&str> = vec!["a"; 50];
1291        let values: Vec<f64> = vec![0.5; 50];
1292        let batch = make_test_batch(names, values);
1293        let input = make_memory_exec(batch);
1294        let result = execute_fold(
1295            input,
1296            vec![0],
1297            vec![FoldBinding {
1298                output_name: "prob".to_string(),
1299                kind: FoldAggKind::Prod,
1300                input_col_index: 1,
1301                input_col_name: None,
1302            }],
1303        )
1304        .await;
1305
1306        let vals = result
1307            .column(1)
1308            .as_any()
1309            .downcast_ref::<Float64Array>()
1310            .unwrap();
1311        let expected = 0.5_f64.powi(50); // ≈ 8.88e-16
1312        assert!(vals.value(0) > 0.0, "should not underflow to zero");
1313        assert!(
1314            (vals.value(0) - expected).abs() / expected < 1e-6,
1315            "result {} should be close to expected {}",
1316            vals.value(0),
1317            expected
1318        );
1319    }
1320
    // ── MNOR/MPROD mathematical correctness tests ───────────────────────
1322
1323    fn make_nullable_test_batch(names: Vec<&str>, values: Vec<Option<f64>>) -> RecordBatch {
1324        let schema = Arc::new(Schema::new(vec![
1325            Field::new("name", DataType::Utf8, true),
1326            Field::new("value", DataType::Float64, true),
1327        ]));
1328        RecordBatch::try_new(
1329            schema,
1330            vec![
1331                Arc::new(StringArray::from(
1332                    names.into_iter().map(Some).collect::<Vec<_>>(),
1333                )),
1334                Arc::new(Float64Array::from(values)),
1335            ],
1336        )
1337        .unwrap()
1338    }
1339
1340    #[tokio::test]
1341    async fn test_nor_single_element() {
1342        // MNOR({0.7}) = 0.7 (n=1 identity)
1343        let batch = make_test_batch(vec!["a"], vec![0.7]);
1344        let input = make_memory_exec(batch);
1345        let result = execute_fold(
1346            input,
1347            vec![0],
1348            vec![FoldBinding {
1349                output_name: "prob".to_string(),
1350                kind: FoldAggKind::Nor,
1351                input_col_index: 1,
1352                input_col_name: None,
1353            }],
1354        )
1355        .await;
1356        let vals = result
1357            .column(1)
1358            .as_any()
1359            .downcast_ref::<Float64Array>()
1360            .unwrap();
1361        assert!((vals.value(0) - 0.7).abs() < 1e-10);
1362    }
1363
1364    #[tokio::test]
1365    async fn test_prod_single_element() {
1366        // MPROD({0.7}) = 0.7 (n=1 identity)
1367        let batch = make_test_batch(vec!["a"], vec![0.7]);
1368        let input = make_memory_exec(batch);
1369        let result = execute_fold(
1370            input,
1371            vec![0],
1372            vec![FoldBinding {
1373                output_name: "prob".to_string(),
1374                kind: FoldAggKind::Prod,
1375                input_col_index: 1,
1376                input_col_name: None,
1377            }],
1378        )
1379        .await;
1380        let vals = result
1381            .column(1)
1382            .as_any()
1383            .downcast_ref::<Float64Array>()
1384            .unwrap();
1385        assert!((vals.value(0) - 0.7).abs() < 1e-10);
1386    }
1387
1388    #[tokio::test]
1389    async fn test_nor_three_elements() {
1390        // MNOR({0.3, 0.4, 0.5}) = 1 - (0.7)(0.6)(0.5) = 0.79
1391        let batch = make_test_batch(vec!["a", "a", "a"], vec![0.3, 0.4, 0.5]);
1392        let input = make_memory_exec(batch);
1393        let result = execute_fold(
1394            input,
1395            vec![0],
1396            vec![FoldBinding {
1397                output_name: "prob".to_string(),
1398                kind: FoldAggKind::Nor,
1399                input_col_index: 1,
1400                input_col_name: None,
1401            }],
1402        )
1403        .await;
1404        let vals = result
1405            .column(1)
1406            .as_any()
1407            .downcast_ref::<Float64Array>()
1408            .unwrap();
1409        assert!((vals.value(0) - 0.79).abs() < 1e-10);
1410    }
1411
1412    #[tokio::test]
1413    async fn test_nor_four_elements_spec_example() {
1414        // Spec §4.5: MNOR({0.72, 0.54, 0.56, 0.42}) = 1 - (0.28)(0.46)(0.44)(0.58) = 0.96713024
1415        let batch = make_test_batch(vec!["a", "a", "a", "a"], vec![0.72, 0.54, 0.56, 0.42]);
1416        let input = make_memory_exec(batch);
1417        let result = execute_fold(
1418            input,
1419            vec![0],
1420            vec![FoldBinding {
1421                output_name: "prob".to_string(),
1422                kind: FoldAggKind::Nor,
1423                input_col_index: 1,
1424                input_col_name: None,
1425            }],
1426        )
1427        .await;
1428        let vals = result
1429            .column(1)
1430            .as_any()
1431            .downcast_ref::<Float64Array>()
1432            .unwrap();
1433        assert!(
1434            (vals.value(0) - 0.96713024).abs() < 1e-10,
1435            "expected 0.96713024, got {}",
1436            vals.value(0)
1437        );
1438    }
1439
1440    #[tokio::test]
1441    async fn test_prod_three_elements() {
1442        // MPROD({0.5, 0.5, 0.5}) = 0.125
1443        let batch = make_test_batch(vec!["a", "a", "a"], vec![0.5, 0.5, 0.5]);
1444        let input = make_memory_exec(batch);
1445        let result = execute_fold(
1446            input,
1447            vec![0],
1448            vec![FoldBinding {
1449                output_name: "prob".to_string(),
1450                kind: FoldAggKind::Prod,
1451                input_col_index: 1,
1452                input_col_name: None,
1453            }],
1454        )
1455        .await;
1456        let vals = result
1457            .column(1)
1458            .as_any()
1459            .downcast_ref::<Float64Array>()
1460            .unwrap();
1461        assert!((vals.value(0) - 0.125).abs() < 1e-10);
1462    }
1463
1464    #[tokio::test]
1465    async fn test_nor_absorbing_element() {
1466        // p=1.0 absorbs: MNOR({0.3, 1.0}) = 1.0
1467        let batch = make_test_batch(vec!["a", "a"], vec![0.3, 1.0]);
1468        let input = make_memory_exec(batch);
1469        let result = execute_fold(
1470            input,
1471            vec![0],
1472            vec![FoldBinding {
1473                output_name: "prob".to_string(),
1474                kind: FoldAggKind::Nor,
1475                input_col_index: 1,
1476                input_col_name: None,
1477            }],
1478        )
1479        .await;
1480        let vals = result
1481            .column(1)
1482            .as_any()
1483            .downcast_ref::<Float64Array>()
1484            .unwrap();
1485        assert!((vals.value(0) - 1.0).abs() < 1e-10);
1486    }
1487
1488    #[tokio::test]
1489    async fn test_prod_clamping() {
1490        // Out-of-range 2.0 clamped to 1.0: MPROD({2.0, 0.5}) = 1.0 * 0.5 = 0.5
1491        let batch = make_test_batch(vec!["a", "a"], vec![2.0, 0.5]);
1492        let input = make_memory_exec(batch);
1493        let result = execute_fold(
1494            input,
1495            vec![0],
1496            vec![FoldBinding {
1497                output_name: "prob".to_string(),
1498                kind: FoldAggKind::Prod,
1499                input_col_index: 1,
1500                input_col_name: None,
1501            }],
1502        )
1503        .await;
1504        let vals = result
1505            .column(1)
1506            .as_any()
1507            .downcast_ref::<Float64Array>()
1508            .unwrap();
1509        assert!((vals.value(0) - 0.5).abs() < 1e-10);
1510    }
1511
1512    #[tokio::test]
1513    async fn test_prod_multiple_groups() {
1514        // a: MPROD({0.6, 0.8}) = 0.48, b: MPROD({0.5, 0.5}) = 0.25
1515        let batch = make_test_batch(vec!["a", "a", "b", "b"], vec![0.6, 0.8, 0.5, 0.5]);
1516        let input = make_memory_exec(batch);
1517        let result = execute_fold(
1518            input,
1519            vec![0],
1520            vec![FoldBinding {
1521                output_name: "prob".to_string(),
1522                kind: FoldAggKind::Prod,
1523                input_col_index: 1,
1524                input_col_name: None,
1525            }],
1526        )
1527        .await;
1528
1529        assert_eq!(result.num_rows(), 2);
1530        let names = result
1531            .column(0)
1532            .as_any()
1533            .downcast_ref::<StringArray>()
1534            .unwrap();
1535        let vals = result
1536            .column(1)
1537            .as_any()
1538            .downcast_ref::<Float64Array>()
1539            .unwrap();
1540        for i in 0..2 {
1541            match names.value(i) {
1542                "a" => assert!((vals.value(i) - 0.48).abs() < 1e-10),
1543                "b" => assert!((vals.value(i) - 0.25).abs() < 1e-10),
1544                _ => panic!("unexpected group name"),
1545            }
1546        }
1547    }
1548
1549    #[tokio::test]
1550    async fn test_nor_commutativity() {
1551        // Order independence: MNOR({0.2, 0.5, 0.8}) = MNOR({0.8, 0.5, 0.2}) = 0.92
1552        let fwd = make_test_batch(vec!["a", "a", "a"], vec![0.2, 0.5, 0.8]);
1553        let rev = make_test_batch(vec!["a", "a", "a"], vec![0.8, 0.5, 0.2]);
1554        let binding = vec![FoldBinding {
1555            output_name: "prob".to_string(),
1556            kind: FoldAggKind::Nor,
1557            input_col_index: 1,
1558            input_col_name: None,
1559        }];
1560        let r1 = execute_fold(make_memory_exec(fwd), vec![0], binding.clone()).await;
1561        let r2 = execute_fold(make_memory_exec(rev), vec![0], binding).await;
1562        let v1 = r1
1563            .column(1)
1564            .as_any()
1565            .downcast_ref::<Float64Array>()
1566            .unwrap()
1567            .value(0);
1568        let v2 = r2
1569            .column(1)
1570            .as_any()
1571            .downcast_ref::<Float64Array>()
1572            .unwrap()
1573            .value(0);
1574        assert!((v1 - 0.92).abs() < 1e-10);
1575        assert!((v2 - 0.92).abs() < 1e-10);
1576        assert!((v1 - v2).abs() < 1e-15, "commutativity violated");
1577    }
1578
1579    #[tokio::test]
1580    async fn test_prod_commutativity() {
1581        // Order independence: MPROD({0.5, 0.25}) = MPROD({0.25, 0.5}) = 0.125
1582        let fwd = make_test_batch(vec!["a", "a"], vec![0.5, 0.25]);
1583        let rev = make_test_batch(vec!["a", "a"], vec![0.25, 0.5]);
1584        let binding = vec![FoldBinding {
1585            output_name: "prob".to_string(),
1586            kind: FoldAggKind::Prod,
1587            input_col_index: 1,
1588            input_col_name: None,
1589        }];
1590        let r1 = execute_fold(make_memory_exec(fwd), vec![0], binding.clone()).await;
1591        let r2 = execute_fold(make_memory_exec(rev), vec![0], binding).await;
1592        let v1 = r1
1593            .column(1)
1594            .as_any()
1595            .downcast_ref::<Float64Array>()
1596            .unwrap()
1597            .value(0);
1598        let v2 = r2
1599            .column(1)
1600            .as_any()
1601            .downcast_ref::<Float64Array>()
1602            .unwrap()
1603            .value(0);
1604        assert!((v1 - 0.125).abs() < 1e-10);
1605        assert!((v2 - 0.125).abs() < 1e-10);
1606        assert!((v1 - v2).abs() < 1e-15, "commutativity violated");
1607    }
1608
1609    #[tokio::test]
1610    async fn test_nor_boundary_near_zero() {
1611        // Precision near 0: MNOR({0.001, 0.002}) = 1 - (0.999)(0.998) = 0.002998
1612        let batch = make_test_batch(vec!["a", "a"], vec![0.001, 0.002]);
1613        let input = make_memory_exec(batch);
1614        let result = execute_fold(
1615            input,
1616            vec![0],
1617            vec![FoldBinding {
1618                output_name: "prob".to_string(),
1619                kind: FoldAggKind::Nor,
1620                input_col_index: 1,
1621                input_col_name: None,
1622            }],
1623        )
1624        .await;
1625        let vals = result
1626            .column(1)
1627            .as_any()
1628            .downcast_ref::<Float64Array>()
1629            .unwrap();
1630        let expected = 1.0 - 0.999 * 0.998;
1631        assert!(
1632            (vals.value(0) - expected).abs() < 1e-10,
1633            "expected {}, got {}",
1634            expected,
1635            vals.value(0)
1636        );
1637    }
1638
1639    #[tokio::test]
1640    async fn test_nor_boundary_near_one() {
1641        // Precision near 1: MNOR({0.999, 0.998}) = 1 - (0.001)(0.002) = 0.999998
1642        let batch = make_test_batch(vec!["a", "a"], vec![0.999, 0.998]);
1643        let input = make_memory_exec(batch);
1644        let result = execute_fold(
1645            input,
1646            vec![0],
1647            vec![FoldBinding {
1648                output_name: "prob".to_string(),
1649                kind: FoldAggKind::Nor,
1650                input_col_index: 1,
1651                input_col_name: None,
1652            }],
1653        )
1654        .await;
1655        let vals = result
1656            .column(1)
1657            .as_any()
1658            .downcast_ref::<Float64Array>()
1659            .unwrap();
1660        let expected = 1.0 - 0.001 * 0.002;
1661        assert!(
1662            (vals.value(0) - expected).abs() < 1e-10,
1663            "expected {}, got {}",
1664            expected,
1665            vals.value(0)
1666        );
1667    }
1668
1669    #[tokio::test]
1670    async fn test_prod_boundary_near_zero() {
1671        // Precision near 0: MPROD({0.001, 0.002}) = 2e-6
1672        let batch = make_test_batch(vec!["a", "a"], vec![0.001, 0.002]);
1673        let input = make_memory_exec(batch);
1674        let result = execute_fold(
1675            input,
1676            vec![0],
1677            vec![FoldBinding {
1678                output_name: "prob".to_string(),
1679                kind: FoldAggKind::Prod,
1680                input_col_index: 1,
1681                input_col_name: None,
1682            }],
1683        )
1684        .await;
1685        let vals = result
1686            .column(1)
1687            .as_any()
1688            .downcast_ref::<Float64Array>()
1689            .unwrap();
1690        assert!(
1691            (vals.value(0) - 2e-6).abs() < 1e-15,
1692            "expected 2e-6, got {}",
1693            vals.value(0)
1694        );
1695    }
1696
1697    #[tokio::test]
1698    async fn test_nor_empty_input() {
1699        // Empty input → 0 rows output
1700        let schema = Arc::new(Schema::new(vec![
1701            Field::new("name", DataType::Utf8, true),
1702            Field::new("value", DataType::Float64, true),
1703        ]));
1704        let batch = RecordBatch::new_empty(schema);
1705        let input = make_memory_exec(batch);
1706        let result = execute_fold(
1707            input,
1708            vec![0],
1709            vec![FoldBinding {
1710                output_name: "prob".to_string(),
1711                kind: FoldAggKind::Nor,
1712                input_col_index: 1,
1713                input_col_name: None,
1714            }],
1715        )
1716        .await;
1717        assert_eq!(result.num_rows(), 0);
1718    }
1719
1720    #[tokio::test]
1721    async fn test_nor_nan_handling() {
1722        // NaN propagates through noisy-OR
1723        let batch = make_test_batch(vec!["a", "a"], vec![0.3, f64::NAN]);
1724        let input = make_memory_exec(batch);
1725        let result = execute_fold(
1726            input,
1727            vec![0],
1728            vec![FoldBinding {
1729                output_name: "prob".to_string(),
1730                kind: FoldAggKind::Nor,
1731                input_col_index: 1,
1732                input_col_name: None,
1733            }],
1734        )
1735        .await;
1736        let vals = result
1737            .column(1)
1738            .as_any()
1739            .downcast_ref::<Float64Array>()
1740            .unwrap();
1741        assert!(vals.value(0).is_nan(), "NaN should propagate through MNOR");
1742    }
1743
1744    #[tokio::test]
1745    async fn test_prod_nan_handling() {
1746        // NaN propagates through product
1747        let batch = make_test_batch(vec!["a", "a"], vec![0.5, f64::NAN]);
1748        let input = make_memory_exec(batch);
1749        let result = execute_fold(
1750            input,
1751            vec![0],
1752            vec![FoldBinding {
1753                output_name: "prob".to_string(),
1754                kind: FoldAggKind::Prod,
1755                input_col_index: 1,
1756                input_col_name: None,
1757            }],
1758        )
1759        .await;
1760        let vals = result
1761            .column(1)
1762            .as_any()
1763            .downcast_ref::<Float64Array>()
1764            .unwrap();
1765        assert!(vals.value(0).is_nan(), "NaN should propagate through MPROD");
1766    }
1767
1768    #[tokio::test]
1769    async fn test_prod_infinity_handling() {
1770        // +∞ clamped to 1.0: MPROD({0.5, ∞}) = 0.5 * 1.0 = 0.5
1771        let batch = make_test_batch(vec!["a", "a"], vec![0.5, f64::INFINITY]);
1772        let input = make_memory_exec(batch);
1773        let result = execute_fold(
1774            input,
1775            vec![0],
1776            vec![FoldBinding {
1777                output_name: "prob".to_string(),
1778                kind: FoldAggKind::Prod,
1779                input_col_index: 1,
1780                input_col_name: None,
1781            }],
1782        )
1783        .await;
1784        let vals = result
1785            .column(1)
1786            .as_any()
1787            .downcast_ref::<Float64Array>()
1788            .unwrap();
1789        assert!((vals.value(0) - 0.5).abs() < 1e-10);
1790    }
1791
1792    #[tokio::test]
1793    async fn test_nor_infinity_handling() {
1794        // +∞ clamped to 1.0, which absorbs: MNOR({0.3, ∞}) = 1.0
1795        let batch = make_test_batch(vec!["a", "a"], vec![0.3, f64::INFINITY]);
1796        let input = make_memory_exec(batch);
1797        let result = execute_fold(
1798            input,
1799            vec![0],
1800            vec![FoldBinding {
1801                output_name: "prob".to_string(),
1802                kind: FoldAggKind::Nor,
1803                input_col_index: 1,
1804                input_col_name: None,
1805            }],
1806        )
1807        .await;
1808        let vals = result
1809            .column(1)
1810            .as_any()
1811            .downcast_ref::<Float64Array>()
1812            .unwrap();
1813        assert!((vals.value(0) - 1.0).abs() < 1e-10);
1814    }
1815
1816    #[tokio::test]
1817    async fn test_nor_all_null_values() {
1818        // All-null input → null output
1819        let batch = make_nullable_test_batch(vec!["a", "a"], vec![None, None]);
1820        let input = make_memory_exec(batch);
1821        let result = execute_fold(
1822            input,
1823            vec![0],
1824            vec![FoldBinding {
1825                output_name: "prob".to_string(),
1826                kind: FoldAggKind::Nor,
1827                input_col_index: 1,
1828                input_col_name: None,
1829            }],
1830        )
1831        .await;
1832        assert_eq!(result.num_rows(), 1);
1833        let vals = result
1834            .column(1)
1835            .as_any()
1836            .downcast_ref::<Float64Array>()
1837            .unwrap();
1838        assert!(vals.is_null(0), "all-null MNOR should produce null");
1839    }
1840
1841    #[tokio::test]
1842    async fn test_prod_all_null_values() {
1843        // All-null input → null output
1844        let batch = make_nullable_test_batch(vec!["a", "a"], vec![None, None]);
1845        let input = make_memory_exec(batch);
1846        let result = execute_fold(
1847            input,
1848            vec![0],
1849            vec![FoldBinding {
1850                output_name: "prob".to_string(),
1851                kind: FoldAggKind::Prod,
1852                input_col_index: 1,
1853                input_col_name: None,
1854            }],
1855        )
1856        .await;
1857        assert_eq!(result.num_rows(), 1);
1858        let vals = result
1859            .column(1)
1860            .as_any()
1861            .downcast_ref::<Float64Array>()
1862            .unwrap();
1863        assert!(vals.is_null(0), "all-null MPROD should produce null");
1864    }
1865
1866    #[tokio::test]
1867    async fn test_nor_mixed_null_values() {
1868        // Nulls skipped: MNOR({0.3, null, 0.5}) = 1 - (0.7)(0.5) = 0.65
1869        let batch = make_nullable_test_batch(vec!["a", "a", "a"], vec![Some(0.3), None, Some(0.5)]);
1870        let input = make_memory_exec(batch);
1871        let result = execute_fold(
1872            input,
1873            vec![0],
1874            vec![FoldBinding {
1875                output_name: "prob".to_string(),
1876                kind: FoldAggKind::Nor,
1877                input_col_index: 1,
1878                input_col_name: None,
1879            }],
1880        )
1881        .await;
1882        let vals = result
1883            .column(1)
1884            .as_any()
1885            .downcast_ref::<Float64Array>()
1886            .unwrap();
1887        assert!((vals.value(0) - 0.65).abs() < 1e-10);
1888    }
1889
1890    #[tokio::test]
1891    async fn test_prod_mixed_null_values() {
1892        // Nulls skipped: MPROD({0.6, null, 0.8}) = 0.6 * 0.8 = 0.48
1893        let batch = make_nullable_test_batch(vec!["a", "a", "a"], vec![Some(0.6), None, Some(0.8)]);
1894        let input = make_memory_exec(batch);
1895        let result = execute_fold(
1896            input,
1897            vec![0],
1898            vec![FoldBinding {
1899                output_name: "prob".to_string(),
1900                kind: FoldAggKind::Prod,
1901                input_col_index: 1,
1902                input_col_name: None,
1903            }],
1904        )
1905        .await;
1906        let vals = result
1907            .column(1)
1908            .as_any()
1909            .downcast_ref::<Float64Array>()
1910            .unwrap();
1911        assert!((vals.value(0) - 0.48).abs() < 1e-10);
1912    }
1913
1914    #[tokio::test]
1915    async fn test_nor_many_small_values() {
1916        // Large accumulation: 20 × 0.1 → 1 - 0.9^20 ≈ 0.8784
1917        let names: Vec<&str> = vec!["a"; 20];
1918        let values: Vec<f64> = vec![0.1; 20];
1919        let batch = make_test_batch(names, values);
1920        let input = make_memory_exec(batch);
1921        let result = execute_fold(
1922            input,
1923            vec![0],
1924            vec![FoldBinding {
1925                output_name: "prob".to_string(),
1926                kind: FoldAggKind::Nor,
1927                input_col_index: 1,
1928                input_col_name: None,
1929            }],
1930        )
1931        .await;
1932        let vals = result
1933            .column(1)
1934            .as_any()
1935            .downcast_ref::<Float64Array>()
1936            .unwrap();
1937        let expected = 1.0 - 0.9_f64.powi(20);
1938        assert!(
1939            (vals.value(0) - expected).abs() < 1e-10,
1940            "expected {}, got {}",
1941            expected,
1942            vals.value(0)
1943        );
1944    }
1945
1946    // ── FoldAggKind classification tests (Phase 1) ────────────────────────
1947
#[test]
fn test_is_monotonic() {
    // Aggregates that must report themselves as monotonic.
    let monotonic = [
        FoldAggKind::Sum,
        FoldAggKind::Max,
        FoldAggKind::Min,
        FoldAggKind::Count,
        FoldAggKind::Nor,
        FoldAggKind::Prod,
    ];
    for kind in &monotonic {
        assert!(kind.is_monotonic());
    }

    // Aggregates that must report themselves as non-monotonic.
    let non_monotonic = [FoldAggKind::Avg, FoldAggKind::Collect];
    for kind in &non_monotonic {
        assert!(!kind.is_monotonic());
    }
}
1959
#[test]
fn test_monotonicity_direction() {
    use super::MonotonicDirection;

    let rising = Some(MonotonicDirection::NonDecreasing);
    let falling = Some(MonotonicDirection::NonIncreasing);

    // (aggregate kind, direction it is expected to report)
    let cases = [
        (FoldAggKind::Sum, rising),
        (FoldAggKind::Max, rising),
        (FoldAggKind::Count, rising),
        (FoldAggKind::Nor, rising),
        (FoldAggKind::Min, falling),
        (FoldAggKind::Prod, falling),
        (FoldAggKind::Avg, None),
        (FoldAggKind::Collect, None),
    ];
    for (kind, expected) in &cases {
        assert_eq!(kind.monotonicity_direction(), *expected);
    }
}
1990
#[test]
fn test_identity_values() {
    // (aggregate kind, identity element it is expected to report)
    let cases: [(FoldAggKind, Option<f64>); 8] = [
        (FoldAggKind::Sum, Some(0.0)),
        (FoldAggKind::Count, Some(0.0)),
        (FoldAggKind::Nor, Some(0.0)),
        (FoldAggKind::Max, Some(f64::NEG_INFINITY)),
        (FoldAggKind::Min, Some(f64::INFINITY)),
        (FoldAggKind::Prod, Some(1.0)),
        (FoldAggKind::Avg, None),
        (FoldAggKind::Collect, None),
    ];
    for (kind, expected) in &cases {
        assert_eq!(kind.identity(), *expected);
    }
}
2002
2003    // ── Strict mode tests (Phase 5) ──────────────────────────────────────
2004
2005    async fn execute_fold_strict(
2006        input: Arc<dyn ExecutionPlan>,
2007        key_indices: Vec<usize>,
2008        fold_bindings: Vec<FoldBinding>,
2009        strict: bool,
2010    ) -> DFResult<RecordBatch> {
2011        let exec = FoldExec::new(input, key_indices, fold_bindings, strict, 1e-15);
2012        let ctx = SessionContext::new();
2013        let task_ctx = ctx.task_ctx();
2014        let stream = exec.execute(0, task_ctx).unwrap();
2015        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream).await?;
2016        if batches.is_empty() {
2017            Ok(RecordBatch::new_empty(exec.schema()))
2018        } else {
2019            arrow::compute::concat_batches(&exec.schema(), &batches).map_err(arrow_err)
2020        }
2021    }
2022
2023    #[tokio::test]
2024    async fn test_nor_strict_rejects_above_one() {
2025        let batch = make_test_batch(vec!["a"], vec![1.5]);
2026        let input = make_memory_exec(batch);
2027        let result = execute_fold_strict(
2028            input,
2029            vec![0],
2030            vec![FoldBinding {
2031                output_name: "p".into(),
2032                kind: FoldAggKind::Nor,
2033                input_col_index: 1,
2034                input_col_name: None,
2035            }],
2036            true,
2037        )
2038        .await;
2039        assert!(result.is_err());
2040        let err = result.unwrap_err().to_string();
2041        assert!(
2042            err.contains("strict_probability_domain"),
2043            "Expected strict error, got: {}",
2044            err
2045        );
2046    }
2047
2048    #[tokio::test]
2049    async fn test_nor_strict_rejects_negative() {
2050        let batch = make_test_batch(vec!["a"], vec![-0.1]);
2051        let input = make_memory_exec(batch);
2052        let result = execute_fold_strict(
2053            input,
2054            vec![0],
2055            vec![FoldBinding {
2056                output_name: "p".into(),
2057                kind: FoldAggKind::Nor,
2058                input_col_index: 1,
2059                input_col_name: None,
2060            }],
2061            true,
2062        )
2063        .await;
2064        assert!(result.is_err());
2065        let err = result.unwrap_err().to_string();
2066        assert!(
2067            err.contains("strict_probability_domain"),
2068            "Expected strict error, got: {}",
2069            err
2070        );
2071    }
2072
2073    #[tokio::test]
2074    async fn test_prod_strict_rejects_above_one() {
2075        let batch = make_test_batch(vec!["a"], vec![2.0]);
2076        let input = make_memory_exec(batch);
2077        let result = execute_fold_strict(
2078            input,
2079            vec![0],
2080            vec![FoldBinding {
2081                output_name: "p".into(),
2082                kind: FoldAggKind::Prod,
2083                input_col_index: 1,
2084                input_col_name: None,
2085            }],
2086            true,
2087        )
2088        .await;
2089        assert!(result.is_err());
2090        let err = result.unwrap_err().to_string();
2091        assert!(
2092            err.contains("strict_probability_domain"),
2093            "Expected strict error, got: {}",
2094            err
2095        );
2096    }
2097
2098    #[tokio::test]
2099    async fn test_prod_strict_rejects_negative() {
2100        let batch = make_test_batch(vec!["a"], vec![-0.5]);
2101        let input = make_memory_exec(batch);
2102        let result = execute_fold_strict(
2103            input,
2104            vec![0],
2105            vec![FoldBinding {
2106                output_name: "p".into(),
2107                kind: FoldAggKind::Prod,
2108                input_col_index: 1,
2109                input_col_name: None,
2110            }],
2111            true,
2112        )
2113        .await;
2114        assert!(result.is_err());
2115        let err = result.unwrap_err().to_string();
2116        assert!(
2117            err.contains("strict_probability_domain"),
2118            "Expected strict error, got: {}",
2119            err
2120        );
2121    }
2122
2123    #[tokio::test]
2124    async fn test_nor_strict_accepts_valid() {
2125        let batch = make_test_batch(vec!["a", "a"], vec![0.3, 0.5]);
2126        let input = make_memory_exec(batch);
2127        let result = execute_fold_strict(
2128            input,
2129            vec![0],
2130            vec![FoldBinding {
2131                output_name: "p".into(),
2132                kind: FoldAggKind::Nor,
2133                input_col_index: 1,
2134                input_col_name: None,
2135            }],
2136            true,
2137        )
2138        .await;
2139        assert!(result.is_ok(), "Expected Ok, got: {:?}", result.err());
2140        let batch = result.unwrap();
2141        let vals = batch
2142            .column(1)
2143            .as_any()
2144            .downcast_ref::<Float64Array>()
2145            .unwrap();
2146        let expected = 0.65; // 1 - (1-0.3)(1-0.5)
2147        assert!(
2148            (vals.value(0) - expected).abs() < 1e-10,
2149            "expected {}, got {}",
2150            expected,
2151            vals.value(0)
2152        );
2153    }
2154
2155    #[tokio::test]
2156    async fn test_count_all_groups_by_key() {
2157        // Two groups: "a" (2 rows), "b" (1 row)
2158        let batch = make_test_batch(vec!["a", "a", "b"], vec![10.0, 20.0, 30.0]);
2159        let input = make_memory_exec(batch);
2160        let result = execute_fold(
2161            input,
2162            vec![0],
2163            vec![FoldBinding {
2164                output_name: "cnt".to_string(),
2165                kind: FoldAggKind::CountAll,
2166                input_col_index: 0, // unused for CountAll
2167                input_col_name: None,
2168            }],
2169        )
2170        .await;
2171
2172        assert_eq!(result.num_rows(), 2, "Should have 2 groups");
2173        let counts = result
2174            .column(1)
2175            .as_any()
2176            .downcast_ref::<Int64Array>()
2177            .unwrap();
2178        assert_eq!(counts.value(0), 2, "Group 'a' should have count 2");
2179        assert_eq!(counts.value(1), 1, "Group 'b' should have count 1");
2180    }
2181}