1use crate::query::df_graph::common::{ScalarKey, compute_plan_properties, extract_scalar_key};
10use arrow_array::builder::{Float64Builder, Int64Builder, LargeBinaryBuilder};
11use arrow_array::{Array, Float64Array, Int64Array, RecordBatch};
12use arrow_schema::{DataType, Field, Schema, SchemaRef};
13use datafusion::common::Result as DFResult;
14use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
15use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
16use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
17use futures::{Stream, TryStreamExt};
18use std::any::Any;
19use std::collections::HashMap;
20use std::fmt;
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24
/// The aggregate operation applied by a [`FoldBinding`] over each group.
///
/// `Copy` is derived since every variant is a unit variant; callers may pass
/// the kind by value instead of by reference.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FoldAggKind {
    /// Numeric sum (output type: Float64).
    Sum,
    /// Maximum value (output type: same as input).
    Max,
    /// Minimum value (output type: same as input).
    Min,
    /// Count of non-null values (output type: Int64).
    Count,
    /// Arithmetic mean of non-null values (output type: Float64).
    Avg,
    /// Gather non-null values into an encoded list (output type: LargeBinary).
    Collect,
}
35
/// Describes one aggregate output column produced by `FoldExec`.
#[derive(Debug, Clone)]
pub struct FoldBinding {
    // Name of the resulting column in the output schema.
    pub output_name: String,
    // Which aggregate to compute for this column.
    pub kind: FoldAggKind,
    // Index (into the child plan's schema) of the column to aggregate.
    pub input_col_index: usize,
}
43
/// Physical operator that groups rows by key columns and computes fold
/// (aggregate) expressions per group, emitting one row per distinct key.
#[derive(Debug)]
pub struct FoldExec {
    // Child plan supplying the rows to aggregate.
    input: Arc<dyn ExecutionPlan>,
    // Indices (into the child schema) of the grouping key columns.
    key_indices: Vec<usize>,
    // Aggregates to compute; one output column each, after the key columns.
    fold_bindings: Vec<FoldBinding>,
    // Output schema: key columns followed by one field per binding.
    schema: SchemaRef,
    // Plan properties derived from the output schema at construction time.
    properties: PlanProperties,
    // Execution metrics (output row counts etc.) for EXPLAIN ANALYZE.
    metrics: ExecutionPlanMetricsSet,
}
57
58impl FoldExec {
59 pub fn new(
66 input: Arc<dyn ExecutionPlan>,
67 key_indices: Vec<usize>,
68 fold_bindings: Vec<FoldBinding>,
69 ) -> Self {
70 let input_schema = input.schema();
71 let schema = Self::build_output_schema(&input_schema, &key_indices, &fold_bindings);
72 let properties = compute_plan_properties(Arc::clone(&schema));
73
74 Self {
75 input,
76 key_indices,
77 fold_bindings,
78 schema,
79 properties,
80 metrics: ExecutionPlanMetricsSet::new(),
81 }
82 }
83
84 fn build_output_schema(
85 input_schema: &SchemaRef,
86 key_indices: &[usize],
87 fold_bindings: &[FoldBinding],
88 ) -> SchemaRef {
89 let mut fields = Vec::new();
90
91 for &ki in key_indices {
93 fields.push(Arc::new(input_schema.field(ki).clone()));
94 }
95
96 for binding in fold_bindings {
98 let input_type = input_schema.field(binding.input_col_index).data_type();
99 let output_type = match binding.kind {
100 FoldAggKind::Sum | FoldAggKind::Avg => DataType::Float64,
101 FoldAggKind::Count => DataType::Int64,
102 FoldAggKind::Max | FoldAggKind::Min => input_type.clone(),
103 FoldAggKind::Collect => DataType::LargeBinary,
104 };
105 fields.push(Arc::new(Field::new(
106 &binding.output_name,
107 output_type,
108 true,
109 )));
110 }
111
112 Arc::new(Schema::new(fields))
113 }
114}
115
116impl DisplayAs for FoldExec {
117 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
118 write!(
119 f,
120 "FoldExec: key_indices={:?}, bindings={:?}",
121 self.key_indices, self.fold_bindings
122 )
123 }
124}
125
126impl ExecutionPlan for FoldExec {
127 fn name(&self) -> &str {
128 "FoldExec"
129 }
130
131 fn as_any(&self) -> &dyn Any {
132 self
133 }
134
135 fn schema(&self) -> SchemaRef {
136 Arc::clone(&self.schema)
137 }
138
139 fn properties(&self) -> &PlanProperties {
140 &self.properties
141 }
142
143 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
144 vec![&self.input]
145 }
146
147 fn with_new_children(
148 self: Arc<Self>,
149 children: Vec<Arc<dyn ExecutionPlan>>,
150 ) -> DFResult<Arc<dyn ExecutionPlan>> {
151 if children.len() != 1 {
152 return Err(datafusion::error::DataFusionError::Plan(
153 "FoldExec requires exactly one child".to_string(),
154 ));
155 }
156 Ok(Arc::new(Self::new(
157 Arc::clone(&children[0]),
158 self.key_indices.clone(),
159 self.fold_bindings.clone(),
160 )))
161 }
162
163 fn execute(
164 &self,
165 partition: usize,
166 context: Arc<TaskContext>,
167 ) -> DFResult<SendableRecordBatchStream> {
168 let input_stream = self.input.execute(partition, Arc::clone(&context))?;
169 let metrics = BaselineMetrics::new(&self.metrics, partition);
170 let key_indices = self.key_indices.clone();
171 let fold_bindings = self.fold_bindings.clone();
172 let output_schema = Arc::clone(&self.schema);
173 let input_schema = self.input.schema();
174
175 let fut = async move {
176 let batches: Vec<RecordBatch> = input_stream.try_collect().await?;
177
178 if batches.is_empty() {
179 return Ok(RecordBatch::new_empty(output_schema));
180 }
181
182 let batch = arrow::compute::concat_batches(&input_schema, &batches)
183 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
184
185 if batch.num_rows() == 0 {
186 return Ok(RecordBatch::new_empty(output_schema));
187 }
188
189 let mut groups: HashMap<Vec<ScalarKey>, Vec<usize>> = HashMap::new();
191 for row_idx in 0..batch.num_rows() {
192 let key = extract_scalar_key(&batch, &key_indices, row_idx);
193 groups.entry(key).or_default().push(row_idx);
194 }
195
196 let mut ordered_keys: Vec<Vec<ScalarKey>> = Vec::new();
198 {
199 let mut seen: std::collections::HashSet<Vec<ScalarKey>> =
200 std::collections::HashSet::new();
201 for row_idx in 0..batch.num_rows() {
202 let key = extract_scalar_key(&batch, &key_indices, row_idx);
203 if seen.insert(key.clone()) {
204 ordered_keys.push(key);
205 }
206 }
207 }
208
209 let num_groups = ordered_keys.len();
210
211 let mut output_columns: Vec<arrow_array::ArrayRef> = Vec::new();
213
214 for &ki in &key_indices {
216 let col = batch.column(ki);
217 let first_indices: Vec<u32> =
218 ordered_keys.iter().map(|k| groups[k][0] as u32).collect();
219 let idx_array = arrow_array::UInt32Array::from(first_indices);
220 let taken = arrow::compute::take(col.as_ref(), &idx_array, None).map_err(|e| {
221 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
222 })?;
223 output_columns.push(taken);
224 }
225
226 for binding in &fold_bindings {
228 let col = batch.column(binding.input_col_index);
229 let agg_col = compute_fold_aggregate(
230 col.as_ref(),
231 &binding.kind,
232 &ordered_keys,
233 &groups,
234 num_groups,
235 )?;
236 output_columns.push(agg_col);
237 }
238
239 RecordBatch::try_new(output_schema, output_columns)
240 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))
241 };
242
243 Ok(Box::pin(FoldStream {
244 state: FoldStreamState::Running(Box::pin(fut)),
245 schema: Arc::clone(&self.schema),
246 metrics,
247 }))
248 }
249
250 fn metrics(&self) -> Option<MetricsSet> {
251 Some(self.metrics.clone_inner())
252 }
253}
254
255fn compute_fold_aggregate(
260 col: &dyn Array,
261 kind: &FoldAggKind,
262 ordered_keys: &[Vec<ScalarKey>],
263 groups: &HashMap<Vec<ScalarKey>, Vec<usize>>,
264 num_groups: usize,
265) -> DFResult<arrow_array::ArrayRef> {
266 match kind {
267 FoldAggKind::Sum => {
268 let mut builder = Float64Builder::with_capacity(num_groups);
269 for key in ordered_keys {
270 let indices = &groups[key];
271 let sum = sum_f64(col, indices);
272 match sum {
273 Some(v) => builder.append_value(v),
274 None => builder.append_null(),
275 }
276 }
277 Ok(Arc::new(builder.finish()))
278 }
279 FoldAggKind::Count => {
280 let mut builder = Int64Builder::with_capacity(num_groups);
281 for key in ordered_keys {
282 let indices = &groups[key];
283 let count = indices.iter().filter(|&&i| !col.is_null(i)).count();
284 builder.append_value(count as i64);
285 }
286 Ok(Arc::new(builder.finish()))
287 }
288 FoldAggKind::Max => compute_minmax(col, ordered_keys, groups, num_groups, false),
289 FoldAggKind::Min => compute_minmax(col, ordered_keys, groups, num_groups, true),
290 FoldAggKind::Avg => {
291 let mut builder = Float64Builder::with_capacity(num_groups);
292 for key in ordered_keys {
293 let indices = &groups[key];
294 let sum = sum_f64(col, indices);
295 let count = indices.iter().filter(|&&i| !col.is_null(i)).count();
296 match (sum, count) {
297 (Some(s), c) if c > 0 => builder.append_value(s / c as f64),
298 _ => builder.append_null(),
299 }
300 }
301 Ok(Arc::new(builder.finish()))
302 }
303 FoldAggKind::Collect => {
304 let mut builder = LargeBinaryBuilder::with_capacity(num_groups, num_groups * 32);
305 for key in ordered_keys {
306 let indices = &groups[key];
307 let mut values = Vec::new();
308 for &i in indices {
309 if !col.is_null(i) {
310 let val = scalar_to_value(col, i);
311 values.push(val);
312 }
313 }
314 let list = uni_common::Value::List(values);
315 let encoded = uni_common::cypher_value_codec::encode(&list);
316 builder.append_value(&encoded);
317 }
318 Ok(Arc::new(builder.finish()))
319 }
320 }
321}
322
323fn sum_f64(col: &dyn Array, indices: &[usize]) -> Option<f64> {
324 let mut sum = 0.0;
325 let mut has_value = false;
326 for &i in indices {
327 if col.is_null(i) {
328 continue;
329 }
330 has_value = true;
331 if let Some(arr) = col.as_any().downcast_ref::<Float64Array>() {
332 sum += arr.value(i);
333 } else if let Some(arr) = col.as_any().downcast_ref::<Int64Array>() {
334 sum += arr.value(i) as f64;
335 }
336 }
337 if has_value { Some(sum) } else { None }
338}
339
340fn compute_minmax(
341 col: &dyn Array,
342 ordered_keys: &[Vec<ScalarKey>],
343 groups: &HashMap<Vec<ScalarKey>, Vec<usize>>,
344 num_groups: usize,
345 is_min: bool,
346) -> DFResult<arrow_array::ArrayRef> {
347 match col.data_type() {
348 DataType::Int64 => {
349 let arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
350 let mut builder = Int64Builder::with_capacity(num_groups);
351 for key in ordered_keys {
352 let indices = &groups[key];
353 let mut result: Option<i64> = None;
354 for &i in indices {
355 if arr.is_null(i) {
356 continue;
357 }
358 let v = arr.value(i);
359 result = Some(match result {
360 None => v,
361 Some(cur) if is_min => cur.min(v),
362 Some(cur) => cur.max(v),
363 });
364 }
365 match result {
366 Some(v) => builder.append_value(v),
367 None => builder.append_null(),
368 }
369 }
370 Ok(Arc::new(builder.finish()))
371 }
372 DataType::Float64 => {
373 let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();
374 let mut builder = Float64Builder::with_capacity(num_groups);
375 for key in ordered_keys {
376 let indices = &groups[key];
377 let mut result: Option<f64> = None;
378 for &i in indices {
379 if arr.is_null(i) {
380 continue;
381 }
382 let v = arr.value(i);
383 result = Some(match result {
384 None => v,
385 Some(cur) if is_min => cur.min(v),
386 Some(cur) => cur.max(v),
387 });
388 }
389 match result {
390 Some(v) => builder.append_value(v),
391 None => builder.append_null(),
392 }
393 }
394 Ok(Arc::new(builder.finish()))
395 }
396 _ => {
397 let mut builder = arrow_array::builder::StringBuilder::new();
399 for key in ordered_keys {
400 let indices = &groups[key];
401 let mut result: Option<String> = None;
402 for &i in indices {
403 if col.is_null(i) {
404 continue;
405 }
406 let v = format!("{:?}", scalar_to_value(col, i));
407 result = Some(match result {
408 None => v.clone(),
409 Some(cur) => {
410 if is_min {
411 if v < cur { v } else { cur }
412 } else if v > cur {
413 v
414 } else {
415 cur
416 }
417 }
418 });
419 }
420 match result {
421 Some(v) => builder.append_value(&v),
422 None => builder.append_null(),
423 }
424 }
425 Ok(Arc::new(builder.finish()))
426 }
427 }
428}
429
430fn scalar_to_value(col: &dyn Array, row_idx: usize) -> uni_common::Value {
431 if col.is_null(row_idx) {
432 return uni_common::Value::Null;
433 }
434 match col.data_type() {
435 DataType::Int64 => {
436 let arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
437 uni_common::Value::Int(arr.value(row_idx))
438 }
439 DataType::Float64 => {
440 let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();
441 uni_common::Value::Float(arr.value(row_idx))
442 }
443 DataType::Utf8 => {
444 let arr = col
445 .as_any()
446 .downcast_ref::<arrow_array::StringArray>()
447 .unwrap();
448 uni_common::Value::String(arr.value(row_idx).to_string())
449 }
450 DataType::Boolean => {
451 let arr = col
452 .as_any()
453 .downcast_ref::<arrow_array::BooleanArray>()
454 .unwrap();
455 uni_common::Value::Bool(arr.value(row_idx))
456 }
457 DataType::LargeBinary => {
458 let arr = col
459 .as_any()
460 .downcast_ref::<arrow_array::LargeBinaryArray>()
461 .unwrap();
462 let bytes = arr.value(row_idx);
463 uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null)
464 }
465 _ => uni_common::Value::Null,
466 }
467}
468
/// State of `FoldStream`: either still driving the aggregation future, or
/// finished (the single result has been yielded or the future failed).
enum FoldStreamState {
    // The boxed aggregation future has not yet completed.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
    // Terminal state; the stream yields `None` from now on.
    Done,
}
477
/// A one-shot `RecordBatchStream` that resolves the fold future into a
/// single output batch and then ends.
struct FoldStream {
    // Drives the aggregation; becomes `Done` after the first result.
    state: FoldStreamState,
    // Output schema reported to consumers via `RecordBatchStream::schema`.
    schema: SchemaRef,
    // Records output row counts for EXPLAIN ANALYZE.
    metrics: BaselineMetrics,
}
483
impl Stream for FoldStream {
    type Item = DFResult<RecordBatch>;

    // Polls the aggregation future once. After it resolves — with either the
    // batch or an error — the state transitions to `Done` and every later
    // poll yields `None`, so the stream behaves as a fused one-item stream.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match &mut self.state {
            FoldStreamState::Running(fut) => match fut.as_mut().poll(cx) {
                Poll::Ready(Ok(batch)) => {
                    // Record output rows before handing the batch out.
                    self.metrics.record_output(batch.num_rows());
                    self.state = FoldStreamState::Done;
                    Poll::Ready(Some(Ok(batch)))
                }
                Poll::Ready(Err(e)) => {
                    // Errors are also terminal: report once, then end.
                    self.state = FoldStreamState::Done;
                    Poll::Ready(Some(Err(e)))
                }
                Poll::Pending => Poll::Pending,
            },
            FoldStreamState::Done => Poll::Ready(None),
        }
    }
}
505
506impl RecordBatchStream for FoldStream {
507 fn schema(&self) -> SchemaRef {
508 Arc::clone(&self.schema)
509 }
510}
511
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Float64Array, Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::physical_plan::memory::MemoryStream;
    use datafusion::prelude::SessionContext;

    /// Builds a two-column (`name: Utf8`, `value: Float64`) test batch.
    fn make_test_batch(names: Vec<&str>, values: Vec<f64>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, true),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(
                    names.into_iter().map(Some).collect::<Vec<_>>(),
                )),
                Arc::new(Float64Array::from(values)),
            ],
        )
        .unwrap()
    }

    /// Wraps a batch in a single-partition in-memory execution plan.
    fn make_memory_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
        let schema = batch.schema();
        Arc::new(TestMemoryExec {
            batches: vec![batch],
            schema: schema.clone(),
            properties: compute_plan_properties(schema),
        })
    }

    /// Minimal leaf plan that replays a fixed set of batches.
    #[derive(Debug)]
    struct TestMemoryExec {
        batches: Vec<RecordBatch>,
        schema: SchemaRef,
        properties: PlanProperties,
    }

    impl DisplayAs for TestMemoryExec {
        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "TestMemoryExec")
        }
    }

    impl ExecutionPlan for TestMemoryExec {
        fn name(&self) -> &str {
            "TestMemoryExec"
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }
        fn properties(&self) -> &PlanProperties {
            &self.properties
        }
        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }
        fn with_new_children(
            self: Arc<Self>,
            _children: Vec<Arc<dyn ExecutionPlan>>,
        ) -> DFResult<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }
        fn execute(
            &self,
            _partition: usize,
            _context: Arc<TaskContext>,
        ) -> DFResult<SendableRecordBatchStream> {
            Ok(Box::pin(MemoryStream::try_new(
                self.batches.clone(),
                Arc::clone(&self.schema),
                None,
            )?))
        }
    }

    /// Runs a `FoldExec` over `input` and concatenates the output batches.
    async fn execute_fold(
        input: Arc<dyn ExecutionPlan>,
        key_indices: Vec<usize>,
        fold_bindings: Vec<FoldBinding>,
    ) -> RecordBatch {
        let exec = FoldExec::new(input, key_indices, fold_bindings);
        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();
        let stream = exec.execute(0, task_ctx).unwrap();
        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream)
            .await
            .unwrap();
        if batches.is_empty() {
            RecordBatch::new_empty(exec.schema())
        } else {
            arrow::compute::concat_batches(&exec.schema(), &batches).unwrap()
        }
    }

    #[tokio::test]
    async fn test_sum_single_group() {
        let batch = make_test_batch(vec!["a", "a", "a"], vec![1.0, 2.0, 3.0]);
        let input = make_memory_exec(batch);
        let result = execute_fold(
            input,
            vec![0],
            vec![FoldBinding {
                output_name: "total".to_string(),
                kind: FoldAggKind::Sum,
                input_col_index: 1,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let totals = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!((totals.value(0) - 6.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn test_count_non_null() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, true),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![Some("a"), Some("a"), Some("a")])),
                Arc::new(Float64Array::from(vec![Some(1.0), None, Some(3.0)])),
            ],
        )
        .unwrap();
        let input = make_memory_exec(batch);
        let result = execute_fold(
            input,
            vec![0],
            vec![FoldBinding {
                output_name: "cnt".to_string(),
                kind: FoldAggKind::Count,
                input_col_index: 1,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let counts = result
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        // The null value must be excluded from the count.
        assert_eq!(counts.value(0), 2);
    }

    #[tokio::test]
    async fn test_max_min() {
        let batch = make_test_batch(vec!["a", "a", "a"], vec![3.0, 1.0, 5.0]);
        let input_max = make_memory_exec(batch.clone());
        let input_min = make_memory_exec(batch);

        let result_max = execute_fold(
            input_max,
            vec![0],
            vec![FoldBinding {
                output_name: "mx".to_string(),
                kind: FoldAggKind::Max,
                input_col_index: 1,
            }],
        )
        .await;
        let result_min = execute_fold(
            input_min,
            vec![0],
            vec![FoldBinding {
                output_name: "mn".to_string(),
                kind: FoldAggKind::Min,
                input_col_index: 1,
            }],
        )
        .await;

        let max_vals = result_max
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(max_vals.value(0), 5.0);

        let min_vals = result_min
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(min_vals.value(0), 1.0);
    }

    #[tokio::test]
    async fn test_avg() {
        let batch = make_test_batch(vec!["a", "a", "a", "a"], vec![2.0, 4.0, 6.0, 8.0]);
        let input = make_memory_exec(batch);
        let result = execute_fold(
            input,
            vec![0],
            vec![FoldBinding {
                output_name: "average".to_string(),
                kind: FoldAggKind::Avg,
                input_col_index: 1,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let avgs = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!((avgs.value(0) - 5.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn test_multiple_groups() {
        let batch = make_test_batch(
            vec!["a", "a", "b", "b", "b"],
            vec![1.0, 2.0, 10.0, 20.0, 30.0],
        );
        let input = make_memory_exec(batch);
        let result = execute_fold(
            input,
            vec![0],
            vec![FoldBinding {
                output_name: "total".to_string(),
                kind: FoldAggKind::Sum,
                input_col_index: 1,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 2);
        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let totals = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();

        // Check each group's total by name rather than relying on row order.
        for i in 0..2 {
            match names.value(i) {
                "a" => assert!((totals.value(i) - 3.0).abs() < f64::EPSILON),
                "b" => assert!((totals.value(i) - 60.0).abs() < f64::EPSILON),
                _ => panic!("unexpected name"),
            }
        }
    }

    #[tokio::test]
    async fn test_empty_input() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, true),
        ]));
        let batch = RecordBatch::new_empty(schema);
        let input = make_memory_exec(batch);
        let result = execute_fold(
            input,
            vec![0],
            vec![FoldBinding {
                output_name: "total".to_string(),
                kind: FoldAggKind::Sum,
                input_col_index: 1,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 0);
    }

    #[tokio::test]
    async fn test_multiple_bindings() {
        let batch = make_test_batch(vec!["a", "a", "a"], vec![1.0, 2.0, 3.0]);
        let input = make_memory_exec(batch);
        let result = execute_fold(
            input,
            vec![0],
            vec![
                FoldBinding {
                    output_name: "total".to_string(),
                    kind: FoldAggKind::Sum,
                    input_col_index: 1,
                },
                FoldBinding {
                    output_name: "cnt".to_string(),
                    kind: FoldAggKind::Count,
                    input_col_index: 1,
                },
                FoldBinding {
                    output_name: "mx".to_string(),
                    kind: FoldAggKind::Max,
                    input_col_index: 1,
                },
            ],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        // One key column plus three aggregate columns.
        assert_eq!(result.num_columns(), 4);

        let totals = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!((totals.value(0) - 6.0).abs() < f64::EPSILON);

        let counts = result
            .column(2)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(counts.value(0), 3);

        let maxes = result
            .column(3)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(maxes.value(0), 3.0);
    }
}