// File: uni_query/query/df_graph/locy_fold.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! FOLD operator for Locy.
5//!
6//! `FoldExec` applies fold (lattice-join) semantics: for each group of rows sharing
7//! the same KEY columns, it reduces non-key columns via their declared fold functions.
8
9use crate::query::df_graph::common::{ScalarKey, compute_plan_properties, extract_scalar_key};
10use arrow_array::builder::{Float64Builder, Int64Builder, LargeBinaryBuilder};
11use arrow_array::{Array, Float64Array, Int64Array, RecordBatch};
12use arrow_schema::{DataType, Field, Schema, SchemaRef};
13use datafusion::common::Result as DFResult;
14use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
15use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
16use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
17use futures::{Stream, TryStreamExt};
18use std::any::Any;
19use std::collections::HashMap;
20use std::fmt;
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24
/// Aggregate function kind for FOLD bindings.
///
/// Determines how the grouped input values of one column are reduced into a
/// single output value per group (see `compute_fold_aggregate`).
// `Copy` added: the enum is fieldless and is cloned throughout the operator;
// copying it is free and simplifies call sites. Backward-compatible.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FoldAggKind {
    /// Sum of non-null numeric values (output: Float64).
    Sum,
    /// Maximum non-null value (output: input type for Int64/Float64).
    Max,
    /// Minimum non-null value (output: input type for Int64/Float64).
    Min,
    /// Count of non-null values (output: Int64).
    Count,
    /// Arithmetic mean of non-null numeric values (output: Float64).
    Avg,
    /// All non-null values collected into one encoded list (output: LargeBinary).
    Collect,
}
35
/// A single FOLD binding: aggregate an input column into an output column.
#[derive(Debug, Clone)]
pub struct FoldBinding {
    // Name of the aggregated column in the output schema.
    pub output_name: String,
    // Which aggregate function to apply to the input column.
    pub kind: FoldAggKind,
    // Index of the source column in the *input* (child) schema.
    pub input_col_index: usize,
}
43
/// DataFusion `ExecutionPlan` that applies FOLD semantics.
///
/// Groups rows by KEY columns and computes aggregates (SUM, MAX, MIN, COUNT, AVG, COLLECT)
/// for each fold binding. Output schema is KEY columns + fold output columns.
#[derive(Debug)]
pub struct FoldExec {
    // Child plan providing the rows to fold.
    input: Arc<dyn ExecutionPlan>,
    // Indices (into the input schema) of the grouping KEY columns.
    key_indices: Vec<usize>,
    // One binding per aggregated output column, in output order.
    fold_bindings: Vec<FoldBinding>,
    // Precomputed output schema: key fields followed by binding fields.
    schema: SchemaRef,
    // Cached plan properties derived from `schema`.
    properties: PlanProperties,
    // Execution metrics shared with the streams produced by `execute`.
    metrics: ExecutionPlanMetricsSet,
}
57
58impl FoldExec {
59    /// Create a new `FoldExec`.
60    ///
61    /// # Arguments
62    /// * `input` - Child execution plan
63    /// * `key_indices` - Indices of KEY columns for grouping
64    /// * `fold_bindings` - Aggregate bindings (output name, kind, input col index)
65    pub fn new(
66        input: Arc<dyn ExecutionPlan>,
67        key_indices: Vec<usize>,
68        fold_bindings: Vec<FoldBinding>,
69    ) -> Self {
70        let input_schema = input.schema();
71        let schema = Self::build_output_schema(&input_schema, &key_indices, &fold_bindings);
72        let properties = compute_plan_properties(Arc::clone(&schema));
73
74        Self {
75            input,
76            key_indices,
77            fold_bindings,
78            schema,
79            properties,
80            metrics: ExecutionPlanMetricsSet::new(),
81        }
82    }
83
84    fn build_output_schema(
85        input_schema: &SchemaRef,
86        key_indices: &[usize],
87        fold_bindings: &[FoldBinding],
88    ) -> SchemaRef {
89        let mut fields = Vec::new();
90
91        // Key columns preserve original types
92        for &ki in key_indices {
93            fields.push(Arc::new(input_schema.field(ki).clone()));
94        }
95
96        // Fold output columns
97        for binding in fold_bindings {
98            let input_type = input_schema.field(binding.input_col_index).data_type();
99            let output_type = match binding.kind {
100                FoldAggKind::Sum | FoldAggKind::Avg => DataType::Float64,
101                FoldAggKind::Count => DataType::Int64,
102                FoldAggKind::Max | FoldAggKind::Min => input_type.clone(),
103                FoldAggKind::Collect => DataType::LargeBinary,
104            };
105            fields.push(Arc::new(Field::new(
106                &binding.output_name,
107                output_type,
108                true,
109            )));
110        }
111
112        Arc::new(Schema::new(fields))
113    }
114}
115
116impl DisplayAs for FoldExec {
117    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
118        write!(
119            f,
120            "FoldExec: key_indices={:?}, bindings={:?}",
121            self.key_indices, self.fold_bindings
122        )
123    }
124}
125
126impl ExecutionPlan for FoldExec {
127    fn name(&self) -> &str {
128        "FoldExec"
129    }
130
131    fn as_any(&self) -> &dyn Any {
132        self
133    }
134
135    fn schema(&self) -> SchemaRef {
136        Arc::clone(&self.schema)
137    }
138
139    fn properties(&self) -> &PlanProperties {
140        &self.properties
141    }
142
143    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
144        vec![&self.input]
145    }
146
147    fn with_new_children(
148        self: Arc<Self>,
149        children: Vec<Arc<dyn ExecutionPlan>>,
150    ) -> DFResult<Arc<dyn ExecutionPlan>> {
151        if children.len() != 1 {
152            return Err(datafusion::error::DataFusionError::Plan(
153                "FoldExec requires exactly one child".to_string(),
154            ));
155        }
156        Ok(Arc::new(Self::new(
157            Arc::clone(&children[0]),
158            self.key_indices.clone(),
159            self.fold_bindings.clone(),
160        )))
161    }
162
163    fn execute(
164        &self,
165        partition: usize,
166        context: Arc<TaskContext>,
167    ) -> DFResult<SendableRecordBatchStream> {
168        let input_stream = self.input.execute(partition, Arc::clone(&context))?;
169        let metrics = BaselineMetrics::new(&self.metrics, partition);
170        let key_indices = self.key_indices.clone();
171        let fold_bindings = self.fold_bindings.clone();
172        let output_schema = Arc::clone(&self.schema);
173        let input_schema = self.input.schema();
174
175        let fut = async move {
176            let batches: Vec<RecordBatch> = input_stream.try_collect().await?;
177
178            if batches.is_empty() {
179                return Ok(RecordBatch::new_empty(output_schema));
180            }
181
182            let batch = arrow::compute::concat_batches(&input_schema, &batches)
183                .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
184
185            if batch.num_rows() == 0 {
186                return Ok(RecordBatch::new_empty(output_schema));
187            }
188
189            // Group by key columns → row indices
190            let mut groups: HashMap<Vec<ScalarKey>, Vec<usize>> = HashMap::new();
191            for row_idx in 0..batch.num_rows() {
192                let key = extract_scalar_key(&batch, &key_indices, row_idx);
193                groups.entry(key).or_default().push(row_idx);
194            }
195
196            // Preserve insertion order by collecting keys in order of first appearance
197            let mut ordered_keys: Vec<Vec<ScalarKey>> = Vec::new();
198            {
199                let mut seen: std::collections::HashSet<Vec<ScalarKey>> =
200                    std::collections::HashSet::new();
201                for row_idx in 0..batch.num_rows() {
202                    let key = extract_scalar_key(&batch, &key_indices, row_idx);
203                    if seen.insert(key.clone()) {
204                        ordered_keys.push(key);
205                    }
206                }
207            }
208
209            let num_groups = ordered_keys.len();
210
211            // Build output columns
212            let mut output_columns: Vec<arrow_array::ArrayRef> = Vec::new();
213
214            // Key columns: take from first row of each group
215            for &ki in &key_indices {
216                let col = batch.column(ki);
217                let first_indices: Vec<u32> =
218                    ordered_keys.iter().map(|k| groups[k][0] as u32).collect();
219                let idx_array = arrow_array::UInt32Array::from(first_indices);
220                let taken = arrow::compute::take(col.as_ref(), &idx_array, None).map_err(|e| {
221                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
222                })?;
223                output_columns.push(taken);
224            }
225
226            // Fold binding columns: compute aggregates per group
227            for binding in &fold_bindings {
228                let col = batch.column(binding.input_col_index);
229                let agg_col = compute_fold_aggregate(
230                    col.as_ref(),
231                    &binding.kind,
232                    &ordered_keys,
233                    &groups,
234                    num_groups,
235                )?;
236                output_columns.push(agg_col);
237            }
238
239            RecordBatch::try_new(output_schema, output_columns)
240                .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))
241        };
242
243        Ok(Box::pin(FoldStream {
244            state: FoldStreamState::Running(Box::pin(fut)),
245            schema: Arc::clone(&self.schema),
246            metrics,
247        }))
248    }
249
250    fn metrics(&self) -> Option<MetricsSet> {
251        Some(self.metrics.clone_inner())
252    }
253}
254
255// ---------------------------------------------------------------------------
256// Aggregate computation
257// ---------------------------------------------------------------------------
258
259fn compute_fold_aggregate(
260    col: &dyn Array,
261    kind: &FoldAggKind,
262    ordered_keys: &[Vec<ScalarKey>],
263    groups: &HashMap<Vec<ScalarKey>, Vec<usize>>,
264    num_groups: usize,
265) -> DFResult<arrow_array::ArrayRef> {
266    match kind {
267        FoldAggKind::Sum => {
268            let mut builder = Float64Builder::with_capacity(num_groups);
269            for key in ordered_keys {
270                let indices = &groups[key];
271                let sum = sum_f64(col, indices);
272                match sum {
273                    Some(v) => builder.append_value(v),
274                    None => builder.append_null(),
275                }
276            }
277            Ok(Arc::new(builder.finish()))
278        }
279        FoldAggKind::Count => {
280            let mut builder = Int64Builder::with_capacity(num_groups);
281            for key in ordered_keys {
282                let indices = &groups[key];
283                let count = indices.iter().filter(|&&i| !col.is_null(i)).count();
284                builder.append_value(count as i64);
285            }
286            Ok(Arc::new(builder.finish()))
287        }
288        FoldAggKind::Max => compute_minmax(col, ordered_keys, groups, num_groups, false),
289        FoldAggKind::Min => compute_minmax(col, ordered_keys, groups, num_groups, true),
290        FoldAggKind::Avg => {
291            let mut builder = Float64Builder::with_capacity(num_groups);
292            for key in ordered_keys {
293                let indices = &groups[key];
294                let sum = sum_f64(col, indices);
295                let count = indices.iter().filter(|&&i| !col.is_null(i)).count();
296                match (sum, count) {
297                    (Some(s), c) if c > 0 => builder.append_value(s / c as f64),
298                    _ => builder.append_null(),
299                }
300            }
301            Ok(Arc::new(builder.finish()))
302        }
303        FoldAggKind::Collect => {
304            let mut builder = LargeBinaryBuilder::with_capacity(num_groups, num_groups * 32);
305            for key in ordered_keys {
306                let indices = &groups[key];
307                let mut values = Vec::new();
308                for &i in indices {
309                    if !col.is_null(i) {
310                        let val = scalar_to_value(col, i);
311                        values.push(val);
312                    }
313                }
314                let list = uni_common::Value::List(values);
315                let encoded = uni_common::cypher_value_codec::encode(&list);
316                builder.append_value(&encoded);
317            }
318            Ok(Arc::new(builder.finish()))
319        }
320    }
321}
322
323fn sum_f64(col: &dyn Array, indices: &[usize]) -> Option<f64> {
324    let mut sum = 0.0;
325    let mut has_value = false;
326    for &i in indices {
327        if col.is_null(i) {
328            continue;
329        }
330        has_value = true;
331        if let Some(arr) = col.as_any().downcast_ref::<Float64Array>() {
332            sum += arr.value(i);
333        } else if let Some(arr) = col.as_any().downcast_ref::<Int64Array>() {
334            sum += arr.value(i) as f64;
335        }
336    }
337    if has_value { Some(sum) } else { None }
338}
339
/// Compute MIN (`is_min == true`) or MAX per group for `col`, emitting one
/// entry per group in `ordered_keys` order; all-null groups produce null.
///
/// Int64 and Float64 are handled natively. Any other type falls back to
/// comparing the `Debug` formatting of each value as a string.
///
/// NOTE(review): the fallback emits a Utf8 array even though
/// `build_output_schema` declares the *input* type for MIN/MAX, so a
/// non-numeric, non-Utf8 input would produce a schema/array mismatch at
/// `RecordBatch::try_new` — confirm intended.
///
/// NOTE(review): for Float64, `f64::min`/`f64::max` return the non-NaN
/// operand when one side is NaN, so NaN inputs are silently dropped from a
/// group's result — confirm this matches the FOLD lattice semantics.
fn compute_minmax(
    col: &dyn Array,
    ordered_keys: &[Vec<ScalarKey>],
    groups: &HashMap<Vec<ScalarKey>, Vec<usize>>,
    num_groups: usize,
    is_min: bool,
) -> DFResult<arrow_array::ArrayRef> {
    match col.data_type() {
        DataType::Int64 => {
            let arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
            let mut builder = Int64Builder::with_capacity(num_groups);
            for key in ordered_keys {
                let indices = &groups[key];
                // Running extremum; stays None until the first non-null value.
                let mut result: Option<i64> = None;
                for &i in indices {
                    if arr.is_null(i) {
                        continue;
                    }
                    let v = arr.value(i);
                    result = Some(match result {
                        None => v,
                        Some(cur) if is_min => cur.min(v),
                        Some(cur) => cur.max(v),
                    });
                }
                match result {
                    Some(v) => builder.append_value(v),
                    // All values in the group were null.
                    None => builder.append_null(),
                }
            }
            Ok(Arc::new(builder.finish()))
        }
        DataType::Float64 => {
            let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();
            let mut builder = Float64Builder::with_capacity(num_groups);
            for key in ordered_keys {
                let indices = &groups[key];
                // Running extremum; stays None until the first non-null value.
                let mut result: Option<f64> = None;
                for &i in indices {
                    if arr.is_null(i) {
                        continue;
                    }
                    let v = arr.value(i);
                    result = Some(match result {
                        None => v,
                        Some(cur) if is_min => cur.min(v),
                        Some(cur) => cur.max(v),
                    });
                }
                match result {
                    Some(v) => builder.append_value(v),
                    // All values in the group were null.
                    None => builder.append_null(),
                }
            }
            Ok(Arc::new(builder.finish()))
        }
        _ => {
            // Fallback: treat as string comparison
            // (lexicographic over `format!("{:?}", value)` — numeric ordering
            // is NOT preserved for types routed through this branch).
            let mut builder = arrow_array::builder::StringBuilder::new();
            for key in ordered_keys {
                let indices = &groups[key];
                let mut result: Option<String> = None;
                for &i in indices {
                    if col.is_null(i) {
                        continue;
                    }
                    let v = format!("{:?}", scalar_to_value(col, i));
                    result = Some(match result {
                        None => v.clone(),
                        Some(cur) => {
                            if is_min {
                                if v < cur { v } else { cur }
                            } else if v > cur {
                                v
                            } else {
                                cur
                            }
                        }
                    });
                }
                match result {
                    Some(v) => builder.append_value(&v),
                    None => builder.append_null(),
                }
            }
            Ok(Arc::new(builder.finish()))
        }
    }
}
429
430fn scalar_to_value(col: &dyn Array, row_idx: usize) -> uni_common::Value {
431    if col.is_null(row_idx) {
432        return uni_common::Value::Null;
433    }
434    match col.data_type() {
435        DataType::Int64 => {
436            let arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
437            uni_common::Value::Int(arr.value(row_idx))
438        }
439        DataType::Float64 => {
440            let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();
441            uni_common::Value::Float(arr.value(row_idx))
442        }
443        DataType::Utf8 => {
444            let arr = col
445                .as_any()
446                .downcast_ref::<arrow_array::StringArray>()
447                .unwrap();
448            uni_common::Value::String(arr.value(row_idx).to_string())
449        }
450        DataType::Boolean => {
451            let arr = col
452                .as_any()
453                .downcast_ref::<arrow_array::BooleanArray>()
454                .unwrap();
455            uni_common::Value::Bool(arr.value(row_idx))
456        }
457        DataType::LargeBinary => {
458            let arr = col
459                .as_any()
460                .downcast_ref::<arrow_array::LargeBinaryArray>()
461                .unwrap();
462            let bytes = arr.value(row_idx);
463            uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null)
464        }
465        _ => uni_common::Value::Null,
466    }
467}
468
469// ---------------------------------------------------------------------------
470// Stream implementation
471// ---------------------------------------------------------------------------
472
// State machine for `FoldStream`: the fold runs as one boxed future that
// yields the single output batch, after which the stream is exhausted.
enum FoldStreamState {
    // The fold future has not completed yet.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
    // The single result (or error) has already been emitted.
    Done,
}
477
// One-shot record batch stream wrapping the fold computation.
struct FoldStream {
    // Drives the fold future to completion, then stays `Done`.
    state: FoldStreamState,
    // Output schema reported via `RecordBatchStream::schema`.
    schema: SchemaRef,
    // Records output row counts for the parent `FoldExec`'s metrics.
    metrics: BaselineMetrics,
}
483
484impl Stream for FoldStream {
485    type Item = DFResult<RecordBatch>;
486
487    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
488        match &mut self.state {
489            FoldStreamState::Running(fut) => match fut.as_mut().poll(cx) {
490                Poll::Ready(Ok(batch)) => {
491                    self.metrics.record_output(batch.num_rows());
492                    self.state = FoldStreamState::Done;
493                    Poll::Ready(Some(Ok(batch)))
494                }
495                Poll::Ready(Err(e)) => {
496                    self.state = FoldStreamState::Done;
497                    Poll::Ready(Some(Err(e)))
498                }
499                Poll::Pending => Poll::Pending,
500            },
501            FoldStreamState::Done => Poll::Ready(None),
502        }
503    }
504}
505
506impl RecordBatchStream for FoldStream {
507    fn schema(&self) -> SchemaRef {
508        Arc::clone(&self.schema)
509    }
510}
511
// Unit tests for `FoldExec`, driven through a minimal in-memory leaf plan.
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Float64Array, Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::physical_plan::memory::MemoryStream;
    use datafusion::prelude::SessionContext;

    // Build a two-column batch (name: Utf8, value: Float64) from parallel vecs.
    fn make_test_batch(names: Vec<&str>, values: Vec<f64>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, true),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(
                    names.into_iter().map(Some).collect::<Vec<_>>(),
                )),
                Arc::new(Float64Array::from(values)),
            ],
        )
        .unwrap()
    }

    // Wrap a batch in a single-partition in-memory execution plan.
    fn make_memory_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
        let schema = batch.schema();
        Arc::new(TestMemoryExec {
            batches: vec![batch],
            schema: schema.clone(),
            properties: compute_plan_properties(schema),
        })
    }

    // Minimal leaf `ExecutionPlan` that replays pre-built batches.
    #[derive(Debug)]
    struct TestMemoryExec {
        batches: Vec<RecordBatch>,
        schema: SchemaRef,
        properties: PlanProperties,
    }

    impl DisplayAs for TestMemoryExec {
        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "TestMemoryExec")
        }
    }

    impl ExecutionPlan for TestMemoryExec {
        fn name(&self) -> &str {
            "TestMemoryExec"
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }
        fn properties(&self) -> &PlanProperties {
            &self.properties
        }
        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }
        fn with_new_children(
            self: Arc<Self>,
            _children: Vec<Arc<dyn ExecutionPlan>>,
        ) -> DFResult<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }
        fn execute(
            &self,
            _partition: usize,
            _context: Arc<TaskContext>,
        ) -> DFResult<SendableRecordBatchStream> {
            Ok(Box::pin(MemoryStream::try_new(
                self.batches.clone(),
                Arc::clone(&self.schema),
                None,
            )?))
        }
    }

    // Run a FoldExec over `input` and concatenate the resulting batches.
    async fn execute_fold(
        input: Arc<dyn ExecutionPlan>,
        key_indices: Vec<usize>,
        fold_bindings: Vec<FoldBinding>,
    ) -> RecordBatch {
        let exec = FoldExec::new(input, key_indices, fold_bindings);
        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();
        let stream = exec.execute(0, task_ctx).unwrap();
        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream)
            .await
            .unwrap();
        if batches.is_empty() {
            RecordBatch::new_empty(exec.schema())
        } else {
            arrow::compute::concat_batches(&exec.schema(), &batches).unwrap()
        }
    }

    // SUM over one group: 1 + 2 + 3 = 6.
    #[tokio::test]
    async fn test_sum_single_group() {
        let batch = make_test_batch(vec!["a", "a", "a"], vec![1.0, 2.0, 3.0]);
        let input = make_memory_exec(batch);
        let result = execute_fold(
            input,
            vec![0],
            vec![FoldBinding {
                output_name: "total".to_string(),
                kind: FoldAggKind::Sum,
                input_col_index: 1,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let totals = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!((totals.value(0) - 6.0).abs() < f64::EPSILON);
    }

    // COUNT skips nulls: of [1.0, null, 3.0] only two values count.
    #[tokio::test]
    async fn test_count_non_null() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, true),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![Some("a"), Some("a"), Some("a")])),
                Arc::new(Float64Array::from(vec![Some(1.0), None, Some(3.0)])),
            ],
        )
        .unwrap();
        let input = make_memory_exec(batch);
        let result = execute_fold(
            input,
            vec![0],
            vec![FoldBinding {
                output_name: "cnt".to_string(),
                kind: FoldAggKind::Count,
                input_col_index: 1,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let counts = result
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(counts.value(0), 2); // null not counted
    }

    // MAX and MIN over the same group: [3, 1, 5] → max 5, min 1.
    #[tokio::test]
    async fn test_max_min() {
        let batch = make_test_batch(vec!["a", "a", "a"], vec![3.0, 1.0, 5.0]);
        let input_max = make_memory_exec(batch.clone());
        let input_min = make_memory_exec(batch);

        let result_max = execute_fold(
            input_max,
            vec![0],
            vec![FoldBinding {
                output_name: "mx".to_string(),
                kind: FoldAggKind::Max,
                input_col_index: 1,
            }],
        )
        .await;
        let result_min = execute_fold(
            input_min,
            vec![0],
            vec![FoldBinding {
                output_name: "mn".to_string(),
                kind: FoldAggKind::Min,
                input_col_index: 1,
            }],
        )
        .await;

        let max_vals = result_max
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(max_vals.value(0), 5.0);

        let min_vals = result_min
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(min_vals.value(0), 1.0);
    }

    // AVG over one group: (2 + 4 + 6 + 8) / 4 = 5.
    #[tokio::test]
    async fn test_avg() {
        let batch = make_test_batch(vec!["a", "a", "a", "a"], vec![2.0, 4.0, 6.0, 8.0]);
        let input = make_memory_exec(batch);
        let result = execute_fold(
            input,
            vec![0],
            vec![FoldBinding {
                output_name: "average".to_string(),
                kind: FoldAggKind::Avg,
                input_col_index: 1,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let avgs = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!((avgs.value(0) - 5.0).abs() < f64::EPSILON);
    }

    // Two key groups fold independently: "a" → 3, "b" → 60.
    #[tokio::test]
    async fn test_multiple_groups() {
        let batch = make_test_batch(
            vec!["a", "a", "b", "b", "b"],
            vec![1.0, 2.0, 10.0, 20.0, 30.0],
        );
        let input = make_memory_exec(batch);
        let result = execute_fold(
            input,
            vec![0],
            vec![FoldBinding {
                output_name: "total".to_string(),
                kind: FoldAggKind::Sum,
                input_col_index: 1,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 2);
        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let totals = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();

        for i in 0..2 {
            match names.value(i) {
                "a" => assert!((totals.value(i) - 3.0).abs() < f64::EPSILON),
                "b" => assert!((totals.value(i) - 60.0).abs() < f64::EPSILON),
                _ => panic!("unexpected name"),
            }
        }
    }

    // An empty input produces an empty (zero-row) output batch.
    #[tokio::test]
    async fn test_empty_input() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, true),
        ]));
        let batch = RecordBatch::new_empty(schema);
        let input = make_memory_exec(batch);
        let result = execute_fold(
            input,
            vec![0],
            vec![FoldBinding {
                output_name: "total".to_string(),
                kind: FoldAggKind::Sum,
                input_col_index: 1,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 0);
    }

    // Several bindings over the same input column produce one output column each.
    #[tokio::test]
    async fn test_multiple_bindings() {
        let batch = make_test_batch(vec!["a", "a", "a"], vec![1.0, 2.0, 3.0]);
        let input = make_memory_exec(batch);
        let result = execute_fold(
            input,
            vec![0],
            vec![
                FoldBinding {
                    output_name: "total".to_string(),
                    kind: FoldAggKind::Sum,
                    input_col_index: 1,
                },
                FoldBinding {
                    output_name: "cnt".to_string(),
                    kind: FoldAggKind::Count,
                    input_col_index: 1,
                },
                FoldBinding {
                    output_name: "mx".to_string(),
                    kind: FoldAggKind::Max,
                    input_col_index: 1,
                },
            ],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        assert_eq!(result.num_columns(), 4); // name + total + cnt + mx

        let totals = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!((totals.value(0) - 6.0).abs() < f64::EPSILON);

        let counts = result
            .column(2)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(counts.value(0), 3);

        let maxes = result
            .column(3)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(maxes.value(0), 3.0);
    }
}