1use crate::query::df_graph::common::{
10 ScalarKey, arrow_err, compute_plan_properties, extract_scalar_key,
11};
12use arrow_array::builder::{Float64Builder, Int64Builder, LargeBinaryBuilder};
13use arrow_array::{Array, Float64Array, Int64Array, RecordBatch};
14use arrow_schema::{DataType, Field, Schema, SchemaRef};
15use datafusion::common::Result as DFResult;
16use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
17use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
18use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
19use futures::{Stream, TryStreamExt};
20use std::any::Any;
21use std::collections::HashMap;
22use std::fmt;
23use std::pin::Pin;
24use std::sync::Arc;
25use std::task::{Context, Poll};
26
/// Direction in which a monotonic fold aggregate is allowed to move.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MonotonicDirection {
    /// The aggregate stays equal or grows.
    NonDecreasing,
    /// The aggregate stays equal or shrinks.
    NonIncreasing,
}

/// The aggregate functions a fold binding can apply to a group of rows.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FoldAggKind {
    /// Sum of the non-null values.
    Sum,
    /// Largest non-null value.
    Max,
    /// Smallest non-null value.
    Min,
    /// Number of non-null values.
    Count,
    /// Number of rows in the group, nulls included.
    CountAll,
    /// Arithmetic mean of the non-null values.
    Avg,
    /// All non-null values gathered into one encoded list.
    Collect,
    /// Noisy-OR combination: `1 - Π(1 - p)`.
    Nor,
    /// Product of the values.
    Prod,
}

impl FoldAggKind {
    /// Returns `true` when this aggregate has a monotonicity direction,
    /// i.e. for every kind except `Avg` and `Collect`.
    pub fn is_monotonic(&self) -> bool {
        self.monotonicity_direction().is_some()
    }

    /// Returns the direction this aggregate moves in, or `None` for the
    /// non-monotonic kinds (`Avg`, `Collect`).
    pub fn monotonicity_direction(&self) -> Option<MonotonicDirection> {
        use MonotonicDirection::{NonDecreasing, NonIncreasing};

        let direction = match self {
            Self::Avg | Self::Collect => return None,
            Self::Min | Self::Prod => NonIncreasing,
            Self::Sum | Self::Max | Self::Count | Self::CountAll | Self::Nor => NonDecreasing,
        };
        Some(direction)
    }

    /// Returns the identity element of the aggregate as an `f64`, or `None`
    /// for the kinds that have none (`Avg`, `Collect`).
    pub fn identity(&self) -> Option<f64> {
        let neutral = match self {
            Self::Avg | Self::Collect => return None,
            Self::Prod => 1.0,
            Self::Min => f64::INFINITY,
            Self::Max => f64::NEG_INFINITY,
            Self::Sum | Self::Count | Self::CountAll | Self::Nor => 0.0,
        };
        Some(neutral)
    }
}
88
/// Describes one aggregate output column produced by a fold.
#[derive(Debug, Clone)]
pub struct FoldBinding {
    /// Name given to the aggregate's column in the output schema.
    pub output_name: String,
    /// Aggregate function applied to each group.
    pub kind: FoldAggKind,
    /// Positional index of the input column; used when `input_col_name` is
    /// absent or cannot be found in the schema.
    pub input_col_index: usize,
    /// Optional input column name; when it resolves in the schema it takes
    /// precedence over `input_col_index`.
    pub input_col_name: Option<String>,
}
99
/// Physical operator that groups its input rows by key columns and folds
/// every group with the configured aggregates.
#[derive(Debug)]
pub struct FoldExec {
    /// Child plan supplying the rows to fold.
    input: Arc<dyn ExecutionPlan>,
    /// Input column positions that make up the grouping key.
    key_indices: Vec<usize>,
    /// One entry per aggregate output column, in output order.
    fold_bindings: Vec<FoldBinding>,
    /// When `true`, `Nor`/`Prod` inputs outside `[0, 1]` raise an execution
    /// error; when `false` they are clamped and a warning is logged.
    strict_probability_domain: bool,
    /// Threshold below which the `Prod` running product switches to
    /// log-space accumulation.
    probability_epsilon: f64,
    /// Output schema: key columns followed by one column per binding.
    schema: SchemaRef,
    /// Plan properties derived from the output schema.
    properties: PlanProperties,
    /// Metrics registry shared by the streams this operator creates.
    metrics: ExecutionPlanMetricsSet,
}
115
116impl FoldExec {
117 pub fn new(
124 input: Arc<dyn ExecutionPlan>,
125 key_indices: Vec<usize>,
126 fold_bindings: Vec<FoldBinding>,
127 strict_probability_domain: bool,
128 probability_epsilon: f64,
129 ) -> Self {
130 let input_schema = input.schema();
131 let schema = Self::build_output_schema(&input_schema, &key_indices, &fold_bindings);
132 let properties = compute_plan_properties(Arc::clone(&schema));
133
134 Self {
135 input,
136 key_indices,
137 fold_bindings,
138 strict_probability_domain,
139 probability_epsilon,
140 schema,
141 properties,
142 metrics: ExecutionPlanMetricsSet::new(),
143 }
144 }
145
146 fn build_output_schema(
147 input_schema: &SchemaRef,
148 key_indices: &[usize],
149 fold_bindings: &[FoldBinding],
150 ) -> SchemaRef {
151 let mut fields = Vec::new();
152
153 for &ki in key_indices {
155 fields.push(Arc::new(input_schema.field(ki).clone()));
156 }
157
158 for binding in fold_bindings {
160 let output_type = match binding.kind {
161 FoldAggKind::Sum | FoldAggKind::Avg | FoldAggKind::Nor | FoldAggKind::Prod => {
162 DataType::Float64
163 }
164 FoldAggKind::Count | FoldAggKind::CountAll => DataType::Int64,
165 FoldAggKind::Max | FoldAggKind::Min => {
166 let idx = binding
167 .input_col_name
168 .as_ref()
169 .and_then(|name| input_schema.index_of(name).ok())
170 .unwrap_or(binding.input_col_index);
171 if idx < input_schema.fields().len() {
172 input_schema.field(idx).data_type().clone()
173 } else {
174 DataType::Float64
175 }
176 }
177 FoldAggKind::Collect => DataType::LargeBinary,
178 };
179 fields.push(Arc::new(Field::new(
180 &binding.output_name,
181 output_type,
182 true,
183 )));
184 }
185
186 Arc::new(Schema::new(fields))
187 }
188}
189
190impl DisplayAs for FoldExec {
191 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
192 write!(
193 f,
194 "FoldExec: key_indices={:?}, bindings={:?}",
195 self.key_indices, self.fold_bindings
196 )
197 }
198}
199
200impl ExecutionPlan for FoldExec {
201 fn name(&self) -> &str {
202 "FoldExec"
203 }
204
205 fn as_any(&self) -> &dyn Any {
206 self
207 }
208
209 fn schema(&self) -> SchemaRef {
210 Arc::clone(&self.schema)
211 }
212
213 fn properties(&self) -> &PlanProperties {
214 &self.properties
215 }
216
217 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
218 vec![&self.input]
219 }
220
221 fn with_new_children(
222 self: Arc<Self>,
223 children: Vec<Arc<dyn ExecutionPlan>>,
224 ) -> DFResult<Arc<dyn ExecutionPlan>> {
225 if children.len() != 1 {
226 return Err(datafusion::error::DataFusionError::Plan(
227 "FoldExec requires exactly one child".to_string(),
228 ));
229 }
230 Ok(Arc::new(Self::new(
231 Arc::clone(&children[0]),
232 self.key_indices.clone(),
233 self.fold_bindings.clone(),
234 self.strict_probability_domain,
235 self.probability_epsilon,
236 )))
237 }
238
239 fn execute(
240 &self,
241 partition: usize,
242 context: Arc<TaskContext>,
243 ) -> DFResult<SendableRecordBatchStream> {
244 let input_stream = self.input.execute(partition, Arc::clone(&context))?;
245 let metrics = BaselineMetrics::new(&self.metrics, partition);
246 let key_indices = self.key_indices.clone();
247 let fold_bindings = self.fold_bindings.clone();
248 let strict = self.strict_probability_domain;
249 let epsilon = self.probability_epsilon;
250 let output_schema = Arc::clone(&self.schema);
251 let input_schema = self.input.schema();
252
253 let fut = async move {
254 let batches: Vec<RecordBatch> = input_stream.try_collect().await?;
255
256 if batches.is_empty() {
257 return Ok(RecordBatch::new_empty(output_schema));
258 }
259
260 let actual_schema = batches
263 .first()
264 .map(|b| b.schema())
265 .unwrap_or(input_schema.clone());
266 let batch =
267 arrow::compute::concat_batches(&actual_schema, &batches).map_err(arrow_err)?;
268
269 if batch.num_rows() == 0 {
270 return Ok(RecordBatch::new_empty(output_schema));
271 }
272
273 let mut groups: HashMap<Vec<ScalarKey>, Vec<usize>> = HashMap::new();
275 let mut ordered_keys: Vec<Vec<ScalarKey>> = Vec::new();
276 for row_idx in 0..batch.num_rows() {
277 let key = extract_scalar_key(&batch, &key_indices, row_idx);
278 let entry = groups.entry(key.clone());
279 if matches!(entry, std::collections::hash_map::Entry::Vacant(_)) {
280 ordered_keys.push(key);
281 }
282 entry.or_default().push(row_idx);
283 }
284
285 let num_groups = ordered_keys.len();
286
287 let mut output_columns: Vec<arrow_array::ArrayRef> = Vec::new();
289
290 for &ki in &key_indices {
292 if ki >= batch.num_columns() {
293 continue; }
295 let col = batch.column(ki);
296 let first_indices: Vec<u32> =
297 ordered_keys.iter().map(|k| groups[k][0] as u32).collect();
298 let idx_array = arrow_array::UInt32Array::from(first_indices);
299 let taken = arrow::compute::take(col.as_ref(), &idx_array, None).map_err(|e| {
300 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
301 })?;
302 output_columns.push(taken);
303 }
304
305 for binding in &fold_bindings {
307 let col: Arc<dyn Array> = if binding.kind == FoldAggKind::CountAll {
308 Arc::new(arrow_array::Int64Array::from(vec![0i64; batch.num_rows()]))
310 } else {
311 let resolved_idx = binding
313 .input_col_name
314 .as_ref()
315 .and_then(|name| batch.schema().index_of(name).ok())
316 .unwrap_or(binding.input_col_index);
317 if resolved_idx < batch.num_columns() {
318 Arc::clone(batch.column(resolved_idx))
319 } else {
320 Arc::new(arrow_array::Float64Array::from(vec![
322 0.0f64;
323 batch.num_rows()
324 ]))
325 }
326 };
327 let agg_col = compute_fold_aggregate(
328 col.as_ref(),
329 &binding.kind,
330 &ordered_keys,
331 &groups,
332 num_groups,
333 strict,
334 epsilon,
335 )?;
336 output_columns.push(agg_col);
337 }
338
339 RecordBatch::try_new(output_schema, output_columns).map_err(arrow_err)
340 };
341
342 Ok(Box::pin(FoldStream {
343 state: FoldStreamState::Running(Box::pin(fut)),
344 schema: Arc::clone(&self.schema),
345 metrics,
346 }))
347 }
348
349 fn metrics(&self) -> Option<MetricsSet> {
350 Some(self.metrics.clone_inner())
351 }
352}
353
354fn compute_fold_aggregate(
359 col: &dyn Array,
360 kind: &FoldAggKind,
361 ordered_keys: &[Vec<ScalarKey>],
362 groups: &HashMap<Vec<ScalarKey>, Vec<usize>>,
363 num_groups: usize,
364 strict: bool,
365 probability_epsilon: f64,
366) -> DFResult<arrow_array::ArrayRef> {
367 match kind {
368 FoldAggKind::Sum => {
369 let mut builder = Float64Builder::with_capacity(num_groups);
370 for key in ordered_keys {
371 builder.append_option(sum_f64(col, &groups[key]));
372 }
373 Ok(Arc::new(builder.finish()))
374 }
375 FoldAggKind::Count => {
376 let mut builder = Int64Builder::with_capacity(num_groups);
377 for key in ordered_keys {
378 let indices = &groups[key];
379 let count = indices.iter().filter(|&&i| !col.is_null(i)).count();
380 builder.append_value(count as i64);
381 }
382 Ok(Arc::new(builder.finish()))
383 }
384 FoldAggKind::CountAll => {
385 let mut builder = Int64Builder::with_capacity(num_groups);
386 for key in ordered_keys {
387 let indices = &groups[key];
388 builder.append_value(indices.len() as i64);
389 }
390 Ok(Arc::new(builder.finish()))
391 }
392 FoldAggKind::Max => compute_minmax(col, ordered_keys, groups, num_groups, false),
393 FoldAggKind::Min => compute_minmax(col, ordered_keys, groups, num_groups, true),
394 FoldAggKind::Avg => {
395 let mut builder = Float64Builder::with_capacity(num_groups);
396 for key in ordered_keys {
397 let indices = &groups[key];
398 let count = indices.iter().filter(|&&i| !col.is_null(i)).count();
399 let avg = sum_f64(col, indices)
400 .filter(|_| count > 0)
401 .map(|s| s / count as f64);
402 builder.append_option(avg);
403 }
404 Ok(Arc::new(builder.finish()))
405 }
406 FoldAggKind::Collect => {
407 let mut builder = LargeBinaryBuilder::with_capacity(num_groups, num_groups * 32);
408 for key in ordered_keys {
409 let values: Vec<uni_common::Value> = groups[key]
410 .iter()
411 .filter(|&&i| !col.is_null(i))
412 .map(|&i| scalar_to_value(col, i))
413 .collect();
414 let encoded =
415 uni_common::cypher_value_codec::encode(&uni_common::Value::List(values));
416 builder.append_value(&encoded);
417 }
418 Ok(Arc::new(builder.finish()))
419 }
420 FoldAggKind::Nor => {
421 let mut builder = Float64Builder::with_capacity(num_groups);
422 for key in ordered_keys {
423 let indices = &groups[key];
424 builder.append_option(noisy_or_f64(col, indices, strict)?);
425 }
426 Ok(Arc::new(builder.finish()))
427 }
428 FoldAggKind::Prod => {
429 let mut builder = Float64Builder::with_capacity(num_groups);
430 for key in ordered_keys {
431 builder.append_option(product_f64(col, &groups[key], strict, probability_epsilon)?);
432 }
433 Ok(Arc::new(builder.finish()))
434 }
435 }
436}
437
438fn sum_f64(col: &dyn Array, indices: &[usize]) -> Option<f64> {
439 let mut sum = 0.0;
440 let mut has_value = false;
441 for &i in indices {
442 if col.is_null(i) {
443 continue;
444 }
445 has_value = true;
446 if let Some(arr) = col.as_any().downcast_ref::<Float64Array>() {
447 sum += arr.value(i);
448 } else if let Some(arr) = col.as_any().downcast_ref::<Int64Array>() {
449 sum += arr.value(i) as f64;
450 }
451 }
452 if has_value { Some(sum) } else { None }
453}
454
455fn noisy_or_f64(col: &dyn Array, indices: &[usize], strict: bool) -> DFResult<Option<f64>> {
457 let mut complement_product = 1.0;
458 let mut has_value = false;
459 for &i in indices {
460 if col.is_null(i) {
461 continue;
462 }
463 has_value = true;
464 let raw = if let Some(arr) = col.as_any().downcast_ref::<Float64Array>() {
465 arr.value(i)
466 } else if let Some(arr) = col.as_any().downcast_ref::<Int64Array>() {
467 arr.value(i) as f64
468 } else {
469 continue;
470 };
471 if strict && !(0.0..=1.0).contains(&raw) {
472 return Err(datafusion::error::DataFusionError::Execution(format!(
473 "strict_probability_domain: MNOR input {raw} is outside [0, 1]"
474 )));
475 }
476 if !strict && !(0.0..=1.0).contains(&raw) {
477 tracing::warn!(
478 "MNOR input {raw} outside [0,1], clamped to {}",
479 raw.clamp(0.0, 1.0)
480 );
481 }
482 let p = raw.clamp(0.0, 1.0);
483 complement_product *= 1.0 - p;
484 }
485 if has_value {
486 Ok(Some(1.0 - complement_product))
487 } else {
488 Ok(None)
489 }
490}
491
492fn product_f64(
497 col: &dyn Array,
498 indices: &[usize],
499 strict: bool,
500 probability_epsilon: f64,
501) -> DFResult<Option<f64>> {
502 let mut product = 1.0;
503 let mut log_sum = 0.0;
504 let mut use_log = false;
505 let mut has_value = false;
506 for &i in indices {
507 if col.is_null(i) {
508 continue;
509 }
510 has_value = true;
511 let raw = if let Some(arr) = col.as_any().downcast_ref::<Float64Array>() {
512 arr.value(i)
513 } else if let Some(arr) = col.as_any().downcast_ref::<Int64Array>() {
514 arr.value(i) as f64
515 } else {
516 continue;
517 };
518 if strict && !(0.0..=1.0).contains(&raw) {
519 return Err(datafusion::error::DataFusionError::Execution(format!(
520 "strict_probability_domain: MPROD input {raw} is outside [0, 1]"
521 )));
522 }
523 if !strict && !(0.0..=1.0).contains(&raw) {
524 tracing::warn!(
525 "MPROD input {raw} outside [0,1], clamped to {}",
526 raw.clamp(0.0, 1.0)
527 );
528 }
529 let p = raw.clamp(0.0, 1.0);
530 if p == 0.0 {
531 return Ok(Some(0.0));
532 }
533 if use_log {
534 log_sum += p.ln();
535 } else {
536 product *= p;
537 if product < probability_epsilon {
538 log_sum = product.ln();
540 use_log = true;
541 }
542 }
543 }
544 if !has_value {
545 return Ok(None);
546 }
547 if use_log {
548 Ok(Some(log_sum.exp()))
549 } else {
550 Ok(Some(product))
551 }
552}
553
554fn compute_minmax(
555 col: &dyn Array,
556 ordered_keys: &[Vec<ScalarKey>],
557 groups: &HashMap<Vec<ScalarKey>, Vec<usize>>,
558 num_groups: usize,
559 is_min: bool,
560) -> DFResult<arrow_array::ArrayRef> {
561 match col.data_type() {
562 DataType::Int64 => {
563 let arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
564 let mut builder = Int64Builder::with_capacity(num_groups);
565 for key in ordered_keys {
566 let mut result: Option<i64> = None;
567 for &i in &groups[key] {
568 if !arr.is_null(i) {
569 let v = arr.value(i);
570 result = Some(match result {
571 None => v,
572 Some(cur) if is_min => cur.min(v),
573 Some(cur) => cur.max(v),
574 });
575 }
576 }
577 builder.append_option(result);
578 }
579 Ok(Arc::new(builder.finish()))
580 }
581 DataType::Float64 => {
582 let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();
583 let mut builder = Float64Builder::with_capacity(num_groups);
584 for key in ordered_keys {
585 let mut result: Option<f64> = None;
586 for &i in &groups[key] {
587 if !arr.is_null(i) {
588 let v = arr.value(i);
589 result = Some(match result {
590 None => v,
591 Some(cur) if is_min => cur.min(v),
592 Some(cur) => cur.max(v),
593 });
594 }
595 }
596 builder.append_option(result);
597 }
598 Ok(Arc::new(builder.finish()))
599 }
600 dt => {
601 let use_large = matches!(dt, DataType::LargeUtf8);
605 let mut values: Vec<Option<String>> = Vec::with_capacity(num_groups);
606 for key in ordered_keys {
607 let indices = &groups[key];
608 let mut result: Option<String> = None;
609 for &i in indices {
610 if col.is_null(i) {
611 continue;
612 }
613 let v = format!("{:?}", scalar_to_value(col, i));
614 result = Some(match result {
615 None => v,
616 Some(cur) if is_min && v < cur => v,
617 Some(cur) if !is_min && v > cur => v,
618 Some(cur) => cur,
619 });
620 }
621 values.push(result);
622 }
623 Ok(build_optional_string_array(&values, use_large))
624 }
625 }
626}
627
628fn build_optional_string_array(
629 values: &[Option<String>],
630 use_large: bool,
631) -> arrow_array::ArrayRef {
632 if use_large {
633 let mut builder = arrow_array::builder::LargeStringBuilder::new();
634 for v in values {
635 match v {
636 Some(s) => builder.append_value(s),
637 None => builder.append_null(),
638 }
639 }
640 Arc::new(builder.finish())
641 } else {
642 let mut builder = arrow_array::builder::StringBuilder::new();
643 for v in values {
644 match v {
645 Some(s) => builder.append_value(s),
646 None => builder.append_null(),
647 }
648 }
649 Arc::new(builder.finish())
650 }
651}
652
653fn scalar_to_value(col: &dyn Array, row_idx: usize) -> uni_common::Value {
654 if col.is_null(row_idx) {
655 return uni_common::Value::Null;
656 }
657 match col.data_type() {
658 DataType::Int64 => {
659 let arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
660 uni_common::Value::Int(arr.value(row_idx))
661 }
662 DataType::Float64 => {
663 let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();
664 uni_common::Value::Float(arr.value(row_idx))
665 }
666 DataType::Utf8 => {
667 let arr = col
668 .as_any()
669 .downcast_ref::<arrow_array::StringArray>()
670 .unwrap();
671 uni_common::Value::String(arr.value(row_idx).to_string())
672 }
673 DataType::LargeUtf8 => {
674 let arr = col
675 .as_any()
676 .downcast_ref::<arrow_array::LargeStringArray>()
677 .unwrap();
678 uni_common::Value::String(arr.value(row_idx).to_string())
679 }
680 DataType::Boolean => {
681 let arr = col
682 .as_any()
683 .downcast_ref::<arrow_array::BooleanArray>()
684 .unwrap();
685 uni_common::Value::Bool(arr.value(row_idx))
686 }
687 DataType::LargeBinary => {
688 let arr = col
689 .as_any()
690 .downcast_ref::<arrow_array::LargeBinaryArray>()
691 .unwrap();
692 let bytes = arr.value(row_idx);
693 uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null)
694 }
695 _ => uni_common::Value::Null,
696 }
697}
698
/// Progress of a `FoldStream`.
enum FoldStreamState {
    /// The future computing the single output batch has not completed yet.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
    /// The result (batch or error) has been handed out; the stream is exhausted.
    Done,
}
707
/// Stream that yields the fold result as one batch and then ends.
struct FoldStream {
    /// Whether the result is still being computed or was already emitted.
    state: FoldStreamState,
    /// Schema reported for the stream.
    schema: SchemaRef,
    /// Metrics recorder; output rows are counted when the batch is emitted.
    metrics: BaselineMetrics,
}
713
714impl Stream for FoldStream {
715 type Item = DFResult<RecordBatch>;
716
717 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
718 match &mut self.state {
719 FoldStreamState::Running(fut) => match fut.as_mut().poll(cx) {
720 Poll::Ready(Ok(batch)) => {
721 self.metrics.record_output(batch.num_rows());
722 self.state = FoldStreamState::Done;
723 Poll::Ready(Some(Ok(batch)))
724 }
725 Poll::Ready(Err(e)) => {
726 self.state = FoldStreamState::Done;
727 Poll::Ready(Some(Err(e)))
728 }
729 Poll::Pending => Poll::Pending,
730 },
731 FoldStreamState::Done => Poll::Ready(None),
732 }
733 }
734}
735
736impl RecordBatchStream for FoldStream {
737 fn schema(&self) -> SchemaRef {
738 Arc::clone(&self.schema)
739 }
740}
741
742#[cfg(test)]
743mod tests {
744 use super::*;
745 use arrow_array::{Float64Array, Int64Array, StringArray};
746 use arrow_schema::{DataType, Field, Schema};
747 use datafusion::physical_plan::memory::MemoryStream;
748 use datafusion::prelude::SessionContext;
749
750 fn make_test_batch(names: Vec<&str>, values: Vec<f64>) -> RecordBatch {
751 let schema = Arc::new(Schema::new(vec![
752 Field::new("name", DataType::Utf8, true),
753 Field::new("value", DataType::Float64, true),
754 ]));
755 RecordBatch::try_new(
756 schema,
757 vec![
758 Arc::new(StringArray::from(
759 names.into_iter().map(Some).collect::<Vec<_>>(),
760 )),
761 Arc::new(Float64Array::from(values)),
762 ],
763 )
764 .unwrap()
765 }
766
767 fn make_memory_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
768 let schema = batch.schema();
769 Arc::new(TestMemoryExec {
770 batches: vec![batch],
771 schema: schema.clone(),
772 properties: compute_plan_properties(schema),
773 })
774 }
775
776 #[derive(Debug)]
777 struct TestMemoryExec {
778 batches: Vec<RecordBatch>,
779 schema: SchemaRef,
780 properties: PlanProperties,
781 }
782
783 impl DisplayAs for TestMemoryExec {
784 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
785 write!(f, "TestMemoryExec")
786 }
787 }
788
789 impl ExecutionPlan for TestMemoryExec {
790 fn name(&self) -> &str {
791 "TestMemoryExec"
792 }
793 fn as_any(&self) -> &dyn Any {
794 self
795 }
796 fn schema(&self) -> SchemaRef {
797 Arc::clone(&self.schema)
798 }
799 fn properties(&self) -> &PlanProperties {
800 &self.properties
801 }
802 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
803 vec![]
804 }
805 fn with_new_children(
806 self: Arc<Self>,
807 _children: Vec<Arc<dyn ExecutionPlan>>,
808 ) -> DFResult<Arc<dyn ExecutionPlan>> {
809 Ok(self)
810 }
811 fn execute(
812 &self,
813 _partition: usize,
814 _context: Arc<TaskContext>,
815 ) -> DFResult<SendableRecordBatchStream> {
816 Ok(Box::pin(MemoryStream::try_new(
817 self.batches.clone(),
818 Arc::clone(&self.schema),
819 None,
820 )?))
821 }
822 }
823
824 async fn execute_fold(
825 input: Arc<dyn ExecutionPlan>,
826 key_indices: Vec<usize>,
827 fold_bindings: Vec<FoldBinding>,
828 ) -> RecordBatch {
829 let exec = FoldExec::new(input, key_indices, fold_bindings, false, 1e-15);
830 let ctx = SessionContext::new();
831 let task_ctx = ctx.task_ctx();
832 let stream = exec.execute(0, task_ctx).unwrap();
833 let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream)
834 .await
835 .unwrap();
836 if batches.is_empty() {
837 RecordBatch::new_empty(exec.schema())
838 } else {
839 arrow::compute::concat_batches(&exec.schema(), &batches).unwrap()
840 }
841 }
842
843 #[tokio::test]
844 async fn test_sum_single_group() {
845 let batch = make_test_batch(vec!["a", "a", "a"], vec![1.0, 2.0, 3.0]);
846 let input = make_memory_exec(batch);
847 let result = execute_fold(
848 input,
849 vec![0],
850 vec![FoldBinding {
851 output_name: "total".to_string(),
852 kind: FoldAggKind::Sum,
853 input_col_index: 1,
854 input_col_name: None,
855 }],
856 )
857 .await;
858
859 assert_eq!(result.num_rows(), 1);
860 let totals = result
861 .column(1)
862 .as_any()
863 .downcast_ref::<Float64Array>()
864 .unwrap();
865 assert!((totals.value(0) - 6.0).abs() < f64::EPSILON);
866 }
867
868 #[tokio::test]
869 async fn test_count_non_null() {
870 let schema = Arc::new(Schema::new(vec![
871 Field::new("name", DataType::Utf8, true),
872 Field::new("value", DataType::Float64, true),
873 ]));
874 let batch = RecordBatch::try_new(
875 schema,
876 vec![
877 Arc::new(StringArray::from(vec![Some("a"), Some("a"), Some("a")])),
878 Arc::new(Float64Array::from(vec![Some(1.0), None, Some(3.0)])),
879 ],
880 )
881 .unwrap();
882 let input = make_memory_exec(batch);
883 let result = execute_fold(
884 input,
885 vec![0],
886 vec![FoldBinding {
887 output_name: "cnt".to_string(),
888 kind: FoldAggKind::Count,
889 input_col_index: 1,
890 input_col_name: None,
891 }],
892 )
893 .await;
894
895 assert_eq!(result.num_rows(), 1);
896 let counts = result
897 .column(1)
898 .as_any()
899 .downcast_ref::<Int64Array>()
900 .unwrap();
901 assert_eq!(counts.value(0), 2); }
903
904 #[tokio::test]
905 async fn test_max_min() {
906 let batch = make_test_batch(vec!["a", "a", "a"], vec![3.0, 1.0, 5.0]);
907 let input_max = make_memory_exec(batch.clone());
908 let input_min = make_memory_exec(batch);
909
910 let result_max = execute_fold(
911 input_max,
912 vec![0],
913 vec![FoldBinding {
914 output_name: "mx".to_string(),
915 kind: FoldAggKind::Max,
916 input_col_index: 1,
917 input_col_name: None,
918 }],
919 )
920 .await;
921 let result_min = execute_fold(
922 input_min,
923 vec![0],
924 vec![FoldBinding {
925 output_name: "mn".to_string(),
926 kind: FoldAggKind::Min,
927 input_col_index: 1,
928 input_col_name: None,
929 }],
930 )
931 .await;
932
933 let max_vals = result_max
934 .column(1)
935 .as_any()
936 .downcast_ref::<Float64Array>()
937 .unwrap();
938 assert_eq!(max_vals.value(0), 5.0);
939
940 let min_vals = result_min
941 .column(1)
942 .as_any()
943 .downcast_ref::<Float64Array>()
944 .unwrap();
945 assert_eq!(min_vals.value(0), 1.0);
946 }
947
948 #[tokio::test]
949 async fn test_avg() {
950 let batch = make_test_batch(vec!["a", "a", "a", "a"], vec![2.0, 4.0, 6.0, 8.0]);
951 let input = make_memory_exec(batch);
952 let result = execute_fold(
953 input,
954 vec![0],
955 vec![FoldBinding {
956 output_name: "average".to_string(),
957 kind: FoldAggKind::Avg,
958 input_col_index: 1,
959 input_col_name: None,
960 }],
961 )
962 .await;
963
964 assert_eq!(result.num_rows(), 1);
965 let avgs = result
966 .column(1)
967 .as_any()
968 .downcast_ref::<Float64Array>()
969 .unwrap();
970 assert!((avgs.value(0) - 5.0).abs() < f64::EPSILON);
971 }
972
973 #[tokio::test]
974 async fn test_multiple_groups() {
975 let batch = make_test_batch(
976 vec!["a", "a", "b", "b", "b"],
977 vec![1.0, 2.0, 10.0, 20.0, 30.0],
978 );
979 let input = make_memory_exec(batch);
980 let result = execute_fold(
981 input,
982 vec![0],
983 vec![FoldBinding {
984 output_name: "total".to_string(),
985 kind: FoldAggKind::Sum,
986 input_col_index: 1,
987 input_col_name: None,
988 }],
989 )
990 .await;
991
992 assert_eq!(result.num_rows(), 2);
993 let names = result
994 .column(0)
995 .as_any()
996 .downcast_ref::<StringArray>()
997 .unwrap();
998 let totals = result
999 .column(1)
1000 .as_any()
1001 .downcast_ref::<Float64Array>()
1002 .unwrap();
1003
1004 for i in 0..2 {
1005 match names.value(i) {
1006 "a" => assert!((totals.value(i) - 3.0).abs() < f64::EPSILON),
1007 "b" => assert!((totals.value(i) - 60.0).abs() < f64::EPSILON),
1008 _ => panic!("unexpected name"),
1009 }
1010 }
1011 }
1012
1013 #[tokio::test]
1014 async fn test_empty_input() {
1015 let schema = Arc::new(Schema::new(vec![
1016 Field::new("name", DataType::Utf8, true),
1017 Field::new("value", DataType::Float64, true),
1018 ]));
1019 let batch = RecordBatch::new_empty(schema);
1020 let input = make_memory_exec(batch);
1021 let result = execute_fold(
1022 input,
1023 vec![0],
1024 vec![FoldBinding {
1025 output_name: "total".to_string(),
1026 kind: FoldAggKind::Sum,
1027 input_col_index: 1,
1028 input_col_name: None,
1029 }],
1030 )
1031 .await;
1032
1033 assert_eq!(result.num_rows(), 0);
1034 }
1035
1036 #[tokio::test]
1037 async fn test_multiple_bindings() {
1038 let batch = make_test_batch(vec!["a", "a", "a"], vec![1.0, 2.0, 3.0]);
1039 let input = make_memory_exec(batch);
1040 let result = execute_fold(
1041 input,
1042 vec![0],
1043 vec![
1044 FoldBinding {
1045 output_name: "total".to_string(),
1046 kind: FoldAggKind::Sum,
1047 input_col_index: 1,
1048 input_col_name: None,
1049 },
1050 FoldBinding {
1051 output_name: "cnt".to_string(),
1052 kind: FoldAggKind::Count,
1053 input_col_index: 1,
1054 input_col_name: None,
1055 },
1056 FoldBinding {
1057 output_name: "mx".to_string(),
1058 kind: FoldAggKind::Max,
1059 input_col_index: 1,
1060 input_col_name: None,
1061 },
1062 ],
1063 )
1064 .await;
1065
1066 assert_eq!(result.num_rows(), 1);
1067 assert_eq!(result.num_columns(), 4); let totals = result
1070 .column(1)
1071 .as_any()
1072 .downcast_ref::<Float64Array>()
1073 .unwrap();
1074 assert!((totals.value(0) - 6.0).abs() < f64::EPSILON);
1075
1076 let counts = result
1077 .column(2)
1078 .as_any()
1079 .downcast_ref::<Int64Array>()
1080 .unwrap();
1081 assert_eq!(counts.value(0), 3);
1082
1083 let maxes = result
1084 .column(3)
1085 .as_any()
1086 .downcast_ref::<Float64Array>()
1087 .unwrap();
1088 assert_eq!(maxes.value(0), 3.0);
1089 }
1090
1091 #[tokio::test]
1094 async fn test_nor_single_group() {
1095 let batch = make_test_batch(vec!["a", "a"], vec![0.3, 0.5]);
1097 let input = make_memory_exec(batch);
1098 let result = execute_fold(
1099 input,
1100 vec![0],
1101 vec![FoldBinding {
1102 output_name: "prob".to_string(),
1103 kind: FoldAggKind::Nor,
1104 input_col_index: 1,
1105 input_col_name: None,
1106 }],
1107 )
1108 .await;
1109
1110 assert_eq!(result.num_rows(), 1);
1111 let vals = result
1112 .column(1)
1113 .as_any()
1114 .downcast_ref::<Float64Array>()
1115 .unwrap();
1116 assert!((vals.value(0) - 0.65).abs() < 1e-10);
1117 }
1118
1119 #[tokio::test]
1120 async fn test_nor_identity() {
1121 let batch = make_test_batch(vec!["a", "a"], vec![0.0, 0.0]);
1123 let input = make_memory_exec(batch);
1124 let result = execute_fold(
1125 input,
1126 vec![0],
1127 vec![FoldBinding {
1128 output_name: "prob".to_string(),
1129 kind: FoldAggKind::Nor,
1130 input_col_index: 1,
1131 input_col_name: None,
1132 }],
1133 )
1134 .await;
1135
1136 let vals = result
1137 .column(1)
1138 .as_any()
1139 .downcast_ref::<Float64Array>()
1140 .unwrap();
1141 assert!((vals.value(0) - 0.0).abs() < 1e-10);
1142 }
1143
1144 #[tokio::test]
1145 async fn test_nor_clamping() {
1146 let batch = make_test_batch(vec!["a", "a"], vec![-0.5, 1.5]);
1148 let input = make_memory_exec(batch);
1149 let result = execute_fold(
1150 input,
1151 vec![0],
1152 vec![FoldBinding {
1153 output_name: "prob".to_string(),
1154 kind: FoldAggKind::Nor,
1155 input_col_index: 1,
1156 input_col_name: None,
1157 }],
1158 )
1159 .await;
1160
1161 let vals = result
1162 .column(1)
1163 .as_any()
1164 .downcast_ref::<Float64Array>()
1165 .unwrap();
1166 assert!((vals.value(0) - 1.0).abs() < 1e-10);
1168 }
1169
1170 #[tokio::test]
1171 async fn test_nor_multiple_groups() {
1172 let batch = make_test_batch(vec!["a", "a", "b", "b"], vec![0.3, 0.5, 0.1, 0.2]);
1173 let input = make_memory_exec(batch);
1174 let result = execute_fold(
1175 input,
1176 vec![0],
1177 vec![FoldBinding {
1178 output_name: "prob".to_string(),
1179 kind: FoldAggKind::Nor,
1180 input_col_index: 1,
1181 input_col_name: None,
1182 }],
1183 )
1184 .await;
1185
1186 assert_eq!(result.num_rows(), 2);
1187 let names = result
1188 .column(0)
1189 .as_any()
1190 .downcast_ref::<StringArray>()
1191 .unwrap();
1192 let vals = result
1193 .column(1)
1194 .as_any()
1195 .downcast_ref::<Float64Array>()
1196 .unwrap();
1197
1198 for i in 0..2 {
1199 match names.value(i) {
1200 "a" => assert!((vals.value(i) - 0.65).abs() < 1e-10),
1202 "b" => assert!((vals.value(i) - 0.28).abs() < 1e-10),
1204 _ => panic!("unexpected name"),
1205 }
1206 }
1207 }
1208
1209 #[tokio::test]
1212 async fn test_prod_single_group() {
1213 let batch = make_test_batch(vec!["a", "a"], vec![0.6, 0.8]);
1215 let input = make_memory_exec(batch);
1216 let result = execute_fold(
1217 input,
1218 vec![0],
1219 vec![FoldBinding {
1220 output_name: "prob".to_string(),
1221 kind: FoldAggKind::Prod,
1222 input_col_index: 1,
1223 input_col_name: None,
1224 }],
1225 )
1226 .await;
1227
1228 assert_eq!(result.num_rows(), 1);
1229 let vals = result
1230 .column(1)
1231 .as_any()
1232 .downcast_ref::<Float64Array>()
1233 .unwrap();
1234 assert!((vals.value(0) - 0.48).abs() < 1e-10);
1235 }
1236
1237 #[tokio::test]
1238 async fn test_prod_identity() {
1239 let batch = make_test_batch(vec!["a", "a"], vec![1.0, 1.0]);
1241 let input = make_memory_exec(batch);
1242 let result = execute_fold(
1243 input,
1244 vec![0],
1245 vec![FoldBinding {
1246 output_name: "prob".to_string(),
1247 kind: FoldAggKind::Prod,
1248 input_col_index: 1,
1249 input_col_name: None,
1250 }],
1251 )
1252 .await;
1253
1254 let vals = result
1255 .column(1)
1256 .as_any()
1257 .downcast_ref::<Float64Array>()
1258 .unwrap();
1259 assert!((vals.value(0) - 1.0).abs() < 1e-10);
1260 }
1261
1262 #[tokio::test]
1263 async fn test_prod_zero_absorbing() {
1264 let batch = make_test_batch(vec!["a", "a", "a"], vec![0.5, 0.0, 0.8]);
1266 let input = make_memory_exec(batch);
1267 let result = execute_fold(
1268 input,
1269 vec![0],
1270 vec![FoldBinding {
1271 output_name: "prob".to_string(),
1272 kind: FoldAggKind::Prod,
1273 input_col_index: 1,
1274 input_col_name: None,
1275 }],
1276 )
1277 .await;
1278
1279 let vals = result
1280 .column(1)
1281 .as_any()
1282 .downcast_ref::<Float64Array>()
1283 .unwrap();
1284 assert!((vals.value(0) - 0.0).abs() < 1e-10);
1285 }
1286
1287 #[tokio::test]
1288 async fn test_prod_underflow_protection() {
1289 let names: Vec<&str> = vec!["a"; 50];
1291 let values: Vec<f64> = vec![0.5; 50];
1292 let batch = make_test_batch(names, values);
1293 let input = make_memory_exec(batch);
1294 let result = execute_fold(
1295 input,
1296 vec![0],
1297 vec![FoldBinding {
1298 output_name: "prob".to_string(),
1299 kind: FoldAggKind::Prod,
1300 input_col_index: 1,
1301 input_col_name: None,
1302 }],
1303 )
1304 .await;
1305
1306 let vals = result
1307 .column(1)
1308 .as_any()
1309 .downcast_ref::<Float64Array>()
1310 .unwrap();
1311 let expected = 0.5_f64.powi(50); assert!(vals.value(0) > 0.0, "should not underflow to zero");
1313 assert!(
1314 (vals.value(0) - expected).abs() / expected < 1e-6,
1315 "result {} should be close to expected {}",
1316 vals.value(0),
1317 expected
1318 );
1319 }
1320
1321 fn make_nullable_test_batch(names: Vec<&str>, values: Vec<Option<f64>>) -> RecordBatch {
1324 let schema = Arc::new(Schema::new(vec![
1325 Field::new("name", DataType::Utf8, true),
1326 Field::new("value", DataType::Float64, true),
1327 ]));
1328 RecordBatch::try_new(
1329 schema,
1330 vec![
1331 Arc::new(StringArray::from(
1332 names.into_iter().map(Some).collect::<Vec<_>>(),
1333 )),
1334 Arc::new(Float64Array::from(values)),
1335 ],
1336 )
1337 .unwrap()
1338 }
1339
1340 #[tokio::test]
1341 async fn test_nor_single_element() {
1342 let batch = make_test_batch(vec!["a"], vec![0.7]);
1344 let input = make_memory_exec(batch);
1345 let result = execute_fold(
1346 input,
1347 vec![0],
1348 vec![FoldBinding {
1349 output_name: "prob".to_string(),
1350 kind: FoldAggKind::Nor,
1351 input_col_index: 1,
1352 input_col_name: None,
1353 }],
1354 )
1355 .await;
1356 let vals = result
1357 .column(1)
1358 .as_any()
1359 .downcast_ref::<Float64Array>()
1360 .unwrap();
1361 assert!((vals.value(0) - 0.7).abs() < 1e-10);
1362 }
1363
1364 #[tokio::test]
1365 async fn test_prod_single_element() {
1366 let batch = make_test_batch(vec!["a"], vec![0.7]);
1368 let input = make_memory_exec(batch);
1369 let result = execute_fold(
1370 input,
1371 vec![0],
1372 vec![FoldBinding {
1373 output_name: "prob".to_string(),
1374 kind: FoldAggKind::Prod,
1375 input_col_index: 1,
1376 input_col_name: None,
1377 }],
1378 )
1379 .await;
1380 let vals = result
1381 .column(1)
1382 .as_any()
1383 .downcast_ref::<Float64Array>()
1384 .unwrap();
1385 assert!((vals.value(0) - 0.7).abs() < 1e-10);
1386 }
1387
1388 #[tokio::test]
1389 async fn test_nor_three_elements() {
1390 let batch = make_test_batch(vec!["a", "a", "a"], vec![0.3, 0.4, 0.5]);
1392 let input = make_memory_exec(batch);
1393 let result = execute_fold(
1394 input,
1395 vec![0],
1396 vec![FoldBinding {
1397 output_name: "prob".to_string(),
1398 kind: FoldAggKind::Nor,
1399 input_col_index: 1,
1400 input_col_name: None,
1401 }],
1402 )
1403 .await;
1404 let vals = result
1405 .column(1)
1406 .as_any()
1407 .downcast_ref::<Float64Array>()
1408 .unwrap();
1409 assert!((vals.value(0) - 0.79).abs() < 1e-10);
1410 }
1411
1412 #[tokio::test]
1413 async fn test_nor_four_elements_spec_example() {
1414 let batch = make_test_batch(vec!["a", "a", "a", "a"], vec![0.72, 0.54, 0.56, 0.42]);
1416 let input = make_memory_exec(batch);
1417 let result = execute_fold(
1418 input,
1419 vec![0],
1420 vec![FoldBinding {
1421 output_name: "prob".to_string(),
1422 kind: FoldAggKind::Nor,
1423 input_col_index: 1,
1424 input_col_name: None,
1425 }],
1426 )
1427 .await;
1428 let vals = result
1429 .column(1)
1430 .as_any()
1431 .downcast_ref::<Float64Array>()
1432 .unwrap();
1433 assert!(
1434 (vals.value(0) - 0.96713024).abs() < 1e-10,
1435 "expected 0.96713024, got {}",
1436 vals.value(0)
1437 );
1438 }
1439
1440 #[tokio::test]
1441 async fn test_prod_three_elements() {
1442 let batch = make_test_batch(vec!["a", "a", "a"], vec![0.5, 0.5, 0.5]);
1444 let input = make_memory_exec(batch);
1445 let result = execute_fold(
1446 input,
1447 vec![0],
1448 vec![FoldBinding {
1449 output_name: "prob".to_string(),
1450 kind: FoldAggKind::Prod,
1451 input_col_index: 1,
1452 input_col_name: None,
1453 }],
1454 )
1455 .await;
1456 let vals = result
1457 .column(1)
1458 .as_any()
1459 .downcast_ref::<Float64Array>()
1460 .unwrap();
1461 assert!((vals.value(0) - 0.125).abs() < 1e-10);
1462 }
1463
1464 #[tokio::test]
1465 async fn test_nor_absorbing_element() {
1466 let batch = make_test_batch(vec!["a", "a"], vec![0.3, 1.0]);
1468 let input = make_memory_exec(batch);
1469 let result = execute_fold(
1470 input,
1471 vec![0],
1472 vec![FoldBinding {
1473 output_name: "prob".to_string(),
1474 kind: FoldAggKind::Nor,
1475 input_col_index: 1,
1476 input_col_name: None,
1477 }],
1478 )
1479 .await;
1480 let vals = result
1481 .column(1)
1482 .as_any()
1483 .downcast_ref::<Float64Array>()
1484 .unwrap();
1485 assert!((vals.value(0) - 1.0).abs() < 1e-10);
1486 }
1487
1488 #[tokio::test]
1489 async fn test_prod_clamping() {
1490 let batch = make_test_batch(vec!["a", "a"], vec![2.0, 0.5]);
1492 let input = make_memory_exec(batch);
1493 let result = execute_fold(
1494 input,
1495 vec![0],
1496 vec![FoldBinding {
1497 output_name: "prob".to_string(),
1498 kind: FoldAggKind::Prod,
1499 input_col_index: 1,
1500 input_col_name: None,
1501 }],
1502 )
1503 .await;
1504 let vals = result
1505 .column(1)
1506 .as_any()
1507 .downcast_ref::<Float64Array>()
1508 .unwrap();
1509 assert!((vals.value(0) - 0.5).abs() < 1e-10);
1510 }
1511
1512 #[tokio::test]
1513 async fn test_prod_multiple_groups() {
1514 let batch = make_test_batch(vec!["a", "a", "b", "b"], vec![0.6, 0.8, 0.5, 0.5]);
1516 let input = make_memory_exec(batch);
1517 let result = execute_fold(
1518 input,
1519 vec![0],
1520 vec![FoldBinding {
1521 output_name: "prob".to_string(),
1522 kind: FoldAggKind::Prod,
1523 input_col_index: 1,
1524 input_col_name: None,
1525 }],
1526 )
1527 .await;
1528
1529 assert_eq!(result.num_rows(), 2);
1530 let names = result
1531 .column(0)
1532 .as_any()
1533 .downcast_ref::<StringArray>()
1534 .unwrap();
1535 let vals = result
1536 .column(1)
1537 .as_any()
1538 .downcast_ref::<Float64Array>()
1539 .unwrap();
1540 for i in 0..2 {
1541 match names.value(i) {
1542 "a" => assert!((vals.value(i) - 0.48).abs() < 1e-10),
1543 "b" => assert!((vals.value(i) - 0.25).abs() < 1e-10),
1544 _ => panic!("unexpected group name"),
1545 }
1546 }
1547 }
1548
1549 #[tokio::test]
1550 async fn test_nor_commutativity() {
1551 let fwd = make_test_batch(vec!["a", "a", "a"], vec![0.2, 0.5, 0.8]);
1553 let rev = make_test_batch(vec!["a", "a", "a"], vec![0.8, 0.5, 0.2]);
1554 let binding = vec![FoldBinding {
1555 output_name: "prob".to_string(),
1556 kind: FoldAggKind::Nor,
1557 input_col_index: 1,
1558 input_col_name: None,
1559 }];
1560 let r1 = execute_fold(make_memory_exec(fwd), vec![0], binding.clone()).await;
1561 let r2 = execute_fold(make_memory_exec(rev), vec![0], binding).await;
1562 let v1 = r1
1563 .column(1)
1564 .as_any()
1565 .downcast_ref::<Float64Array>()
1566 .unwrap()
1567 .value(0);
1568 let v2 = r2
1569 .column(1)
1570 .as_any()
1571 .downcast_ref::<Float64Array>()
1572 .unwrap()
1573 .value(0);
1574 assert!((v1 - 0.92).abs() < 1e-10);
1575 assert!((v2 - 0.92).abs() < 1e-10);
1576 assert!((v1 - v2).abs() < 1e-15, "commutativity violated");
1577 }
1578
1579 #[tokio::test]
1580 async fn test_prod_commutativity() {
1581 let fwd = make_test_batch(vec!["a", "a"], vec![0.5, 0.25]);
1583 let rev = make_test_batch(vec!["a", "a"], vec![0.25, 0.5]);
1584 let binding = vec![FoldBinding {
1585 output_name: "prob".to_string(),
1586 kind: FoldAggKind::Prod,
1587 input_col_index: 1,
1588 input_col_name: None,
1589 }];
1590 let r1 = execute_fold(make_memory_exec(fwd), vec![0], binding.clone()).await;
1591 let r2 = execute_fold(make_memory_exec(rev), vec![0], binding).await;
1592 let v1 = r1
1593 .column(1)
1594 .as_any()
1595 .downcast_ref::<Float64Array>()
1596 .unwrap()
1597 .value(0);
1598 let v2 = r2
1599 .column(1)
1600 .as_any()
1601 .downcast_ref::<Float64Array>()
1602 .unwrap()
1603 .value(0);
1604 assert!((v1 - 0.125).abs() < 1e-10);
1605 assert!((v2 - 0.125).abs() < 1e-10);
1606 assert!((v1 - v2).abs() < 1e-15, "commutativity violated");
1607 }
1608
1609 #[tokio::test]
1610 async fn test_nor_boundary_near_zero() {
1611 let batch = make_test_batch(vec!["a", "a"], vec![0.001, 0.002]);
1613 let input = make_memory_exec(batch);
1614 let result = execute_fold(
1615 input,
1616 vec![0],
1617 vec![FoldBinding {
1618 output_name: "prob".to_string(),
1619 kind: FoldAggKind::Nor,
1620 input_col_index: 1,
1621 input_col_name: None,
1622 }],
1623 )
1624 .await;
1625 let vals = result
1626 .column(1)
1627 .as_any()
1628 .downcast_ref::<Float64Array>()
1629 .unwrap();
1630 let expected = 1.0 - 0.999 * 0.998;
1631 assert!(
1632 (vals.value(0) - expected).abs() < 1e-10,
1633 "expected {}, got {}",
1634 expected,
1635 vals.value(0)
1636 );
1637 }
1638
1639 #[tokio::test]
1640 async fn test_nor_boundary_near_one() {
1641 let batch = make_test_batch(vec!["a", "a"], vec![0.999, 0.998]);
1643 let input = make_memory_exec(batch);
1644 let result = execute_fold(
1645 input,
1646 vec![0],
1647 vec![FoldBinding {
1648 output_name: "prob".to_string(),
1649 kind: FoldAggKind::Nor,
1650 input_col_index: 1,
1651 input_col_name: None,
1652 }],
1653 )
1654 .await;
1655 let vals = result
1656 .column(1)
1657 .as_any()
1658 .downcast_ref::<Float64Array>()
1659 .unwrap();
1660 let expected = 1.0 - 0.001 * 0.002;
1661 assert!(
1662 (vals.value(0) - expected).abs() < 1e-10,
1663 "expected {}, got {}",
1664 expected,
1665 vals.value(0)
1666 );
1667 }
1668
1669 #[tokio::test]
1670 async fn test_prod_boundary_near_zero() {
1671 let batch = make_test_batch(vec!["a", "a"], vec![0.001, 0.002]);
1673 let input = make_memory_exec(batch);
1674 let result = execute_fold(
1675 input,
1676 vec![0],
1677 vec![FoldBinding {
1678 output_name: "prob".to_string(),
1679 kind: FoldAggKind::Prod,
1680 input_col_index: 1,
1681 input_col_name: None,
1682 }],
1683 )
1684 .await;
1685 let vals = result
1686 .column(1)
1687 .as_any()
1688 .downcast_ref::<Float64Array>()
1689 .unwrap();
1690 assert!(
1691 (vals.value(0) - 2e-6).abs() < 1e-15,
1692 "expected 2e-6, got {}",
1693 vals.value(0)
1694 );
1695 }
1696
1697 #[tokio::test]
1698 async fn test_nor_empty_input() {
1699 let schema = Arc::new(Schema::new(vec![
1701 Field::new("name", DataType::Utf8, true),
1702 Field::new("value", DataType::Float64, true),
1703 ]));
1704 let batch = RecordBatch::new_empty(schema);
1705 let input = make_memory_exec(batch);
1706 let result = execute_fold(
1707 input,
1708 vec![0],
1709 vec![FoldBinding {
1710 output_name: "prob".to_string(),
1711 kind: FoldAggKind::Nor,
1712 input_col_index: 1,
1713 input_col_name: None,
1714 }],
1715 )
1716 .await;
1717 assert_eq!(result.num_rows(), 0);
1718 }
1719
1720 #[tokio::test]
1721 async fn test_nor_nan_handling() {
1722 let batch = make_test_batch(vec!["a", "a"], vec![0.3, f64::NAN]);
1724 let input = make_memory_exec(batch);
1725 let result = execute_fold(
1726 input,
1727 vec![0],
1728 vec![FoldBinding {
1729 output_name: "prob".to_string(),
1730 kind: FoldAggKind::Nor,
1731 input_col_index: 1,
1732 input_col_name: None,
1733 }],
1734 )
1735 .await;
1736 let vals = result
1737 .column(1)
1738 .as_any()
1739 .downcast_ref::<Float64Array>()
1740 .unwrap();
1741 assert!(vals.value(0).is_nan(), "NaN should propagate through MNOR");
1742 }
1743
1744 #[tokio::test]
1745 async fn test_prod_nan_handling() {
1746 let batch = make_test_batch(vec!["a", "a"], vec![0.5, f64::NAN]);
1748 let input = make_memory_exec(batch);
1749 let result = execute_fold(
1750 input,
1751 vec![0],
1752 vec![FoldBinding {
1753 output_name: "prob".to_string(),
1754 kind: FoldAggKind::Prod,
1755 input_col_index: 1,
1756 input_col_name: None,
1757 }],
1758 )
1759 .await;
1760 let vals = result
1761 .column(1)
1762 .as_any()
1763 .downcast_ref::<Float64Array>()
1764 .unwrap();
1765 assert!(vals.value(0).is_nan(), "NaN should propagate through MPROD");
1766 }
1767
1768 #[tokio::test]
1769 async fn test_prod_infinity_handling() {
1770 let batch = make_test_batch(vec!["a", "a"], vec![0.5, f64::INFINITY]);
1772 let input = make_memory_exec(batch);
1773 let result = execute_fold(
1774 input,
1775 vec![0],
1776 vec![FoldBinding {
1777 output_name: "prob".to_string(),
1778 kind: FoldAggKind::Prod,
1779 input_col_index: 1,
1780 input_col_name: None,
1781 }],
1782 )
1783 .await;
1784 let vals = result
1785 .column(1)
1786 .as_any()
1787 .downcast_ref::<Float64Array>()
1788 .unwrap();
1789 assert!((vals.value(0) - 0.5).abs() < 1e-10);
1790 }
1791
1792 #[tokio::test]
1793 async fn test_nor_infinity_handling() {
1794 let batch = make_test_batch(vec!["a", "a"], vec![0.3, f64::INFINITY]);
1796 let input = make_memory_exec(batch);
1797 let result = execute_fold(
1798 input,
1799 vec![0],
1800 vec![FoldBinding {
1801 output_name: "prob".to_string(),
1802 kind: FoldAggKind::Nor,
1803 input_col_index: 1,
1804 input_col_name: None,
1805 }],
1806 )
1807 .await;
1808 let vals = result
1809 .column(1)
1810 .as_any()
1811 .downcast_ref::<Float64Array>()
1812 .unwrap();
1813 assert!((vals.value(0) - 1.0).abs() < 1e-10);
1814 }
1815
1816 #[tokio::test]
1817 async fn test_nor_all_null_values() {
1818 let batch = make_nullable_test_batch(vec!["a", "a"], vec![None, None]);
1820 let input = make_memory_exec(batch);
1821 let result = execute_fold(
1822 input,
1823 vec![0],
1824 vec![FoldBinding {
1825 output_name: "prob".to_string(),
1826 kind: FoldAggKind::Nor,
1827 input_col_index: 1,
1828 input_col_name: None,
1829 }],
1830 )
1831 .await;
1832 assert_eq!(result.num_rows(), 1);
1833 let vals = result
1834 .column(1)
1835 .as_any()
1836 .downcast_ref::<Float64Array>()
1837 .unwrap();
1838 assert!(vals.is_null(0), "all-null MNOR should produce null");
1839 }
1840
1841 #[tokio::test]
1842 async fn test_prod_all_null_values() {
1843 let batch = make_nullable_test_batch(vec!["a", "a"], vec![None, None]);
1845 let input = make_memory_exec(batch);
1846 let result = execute_fold(
1847 input,
1848 vec![0],
1849 vec![FoldBinding {
1850 output_name: "prob".to_string(),
1851 kind: FoldAggKind::Prod,
1852 input_col_index: 1,
1853 input_col_name: None,
1854 }],
1855 )
1856 .await;
1857 assert_eq!(result.num_rows(), 1);
1858 let vals = result
1859 .column(1)
1860 .as_any()
1861 .downcast_ref::<Float64Array>()
1862 .unwrap();
1863 assert!(vals.is_null(0), "all-null MPROD should produce null");
1864 }
1865
1866 #[tokio::test]
1867 async fn test_nor_mixed_null_values() {
1868 let batch = make_nullable_test_batch(vec!["a", "a", "a"], vec![Some(0.3), None, Some(0.5)]);
1870 let input = make_memory_exec(batch);
1871 let result = execute_fold(
1872 input,
1873 vec![0],
1874 vec![FoldBinding {
1875 output_name: "prob".to_string(),
1876 kind: FoldAggKind::Nor,
1877 input_col_index: 1,
1878 input_col_name: None,
1879 }],
1880 )
1881 .await;
1882 let vals = result
1883 .column(1)
1884 .as_any()
1885 .downcast_ref::<Float64Array>()
1886 .unwrap();
1887 assert!((vals.value(0) - 0.65).abs() < 1e-10);
1888 }
1889
1890 #[tokio::test]
1891 async fn test_prod_mixed_null_values() {
1892 let batch = make_nullable_test_batch(vec!["a", "a", "a"], vec![Some(0.6), None, Some(0.8)]);
1894 let input = make_memory_exec(batch);
1895 let result = execute_fold(
1896 input,
1897 vec![0],
1898 vec![FoldBinding {
1899 output_name: "prob".to_string(),
1900 kind: FoldAggKind::Prod,
1901 input_col_index: 1,
1902 input_col_name: None,
1903 }],
1904 )
1905 .await;
1906 let vals = result
1907 .column(1)
1908 .as_any()
1909 .downcast_ref::<Float64Array>()
1910 .unwrap();
1911 assert!((vals.value(0) - 0.48).abs() < 1e-10);
1912 }
1913
1914 #[tokio::test]
1915 async fn test_nor_many_small_values() {
1916 let names: Vec<&str> = vec!["a"; 20];
1918 let values: Vec<f64> = vec![0.1; 20];
1919 let batch = make_test_batch(names, values);
1920 let input = make_memory_exec(batch);
1921 let result = execute_fold(
1922 input,
1923 vec![0],
1924 vec![FoldBinding {
1925 output_name: "prob".to_string(),
1926 kind: FoldAggKind::Nor,
1927 input_col_index: 1,
1928 input_col_name: None,
1929 }],
1930 )
1931 .await;
1932 let vals = result
1933 .column(1)
1934 .as_any()
1935 .downcast_ref::<Float64Array>()
1936 .unwrap();
1937 let expected = 1.0 - 0.9_f64.powi(20);
1938 assert!(
1939 (vals.value(0) - expected).abs() < 1e-10,
1940 "expected {}, got {}",
1941 expected,
1942 vals.value(0)
1943 );
1944 }
1945
/// Checks `FoldAggKind::is_monotonic` for every variant.
///
/// Sum, Max, Min, Count, CountAll, Nor and Prod must report `true`; Avg and
/// Collect must report `false`.
#[test]
fn test_is_monotonic() {
    assert!(FoldAggKind::Sum.is_monotonic());
    assert!(FoldAggKind::Max.is_monotonic());
    assert!(FoldAggKind::Min.is_monotonic());
    assert!(FoldAggKind::Count.is_monotonic());
    // CountAll is part of the monotonic set as well; cover it explicitly so a
    // regression in that match arm is caught.
    assert!(FoldAggKind::CountAll.is_monotonic());
    assert!(FoldAggKind::Nor.is_monotonic());
    assert!(FoldAggKind::Prod.is_monotonic());
    assert!(!FoldAggKind::Avg.is_monotonic());
    assert!(!FoldAggKind::Collect.is_monotonic());
}
1959
/// Checks `FoldAggKind::monotonicity_direction` for every variant.
///
/// Sum, Max, Count, CountAll and Nor must be non-decreasing; Min and Prod must
/// be non-increasing; Avg and Collect have no direction.
#[test]
fn test_monotonicity_direction() {
    use super::MonotonicDirection;

    let non_decreasing = [
        FoldAggKind::Sum,
        FoldAggKind::Max,
        FoldAggKind::Count,
        // CountAll was previously not covered by this test.
        FoldAggKind::CountAll,
        FoldAggKind::Nor,
    ];
    for kind in &non_decreasing {
        assert_eq!(
            kind.monotonicity_direction(),
            Some(MonotonicDirection::NonDecreasing),
            "{:?} should be non-decreasing",
            kind
        );
    }

    let non_increasing = [FoldAggKind::Min, FoldAggKind::Prod];
    for kind in &non_increasing {
        assert_eq!(
            kind.monotonicity_direction(),
            Some(MonotonicDirection::NonIncreasing),
            "{:?} should be non-increasing",
            kind
        );
    }

    let undirected = [FoldAggKind::Avg, FoldAggKind::Collect];
    for kind in &undirected {
        assert_eq!(
            kind.monotonicity_direction(),
            None,
            "{:?} should have no monotonic direction",
            kind
        );
    }
}
1990
/// Checks `FoldAggKind::identity` for every variant.
///
/// Sum, Count, CountAll and Nor use 0.0; Max uses negative infinity; Min uses
/// positive infinity; Prod uses 1.0; Avg and Collect have no identity.
#[test]
fn test_identity_values() {
    assert_eq!(FoldAggKind::Sum.identity(), Some(0.0));
    assert_eq!(FoldAggKind::Count.identity(), Some(0.0));
    // CountAll shares the 0.0 identity; cover it explicitly.
    assert_eq!(FoldAggKind::CountAll.identity(), Some(0.0));
    assert_eq!(FoldAggKind::Nor.identity(), Some(0.0));
    assert_eq!(FoldAggKind::Max.identity(), Some(f64::NEG_INFINITY));
    assert_eq!(FoldAggKind::Min.identity(), Some(f64::INFINITY));
    assert_eq!(FoldAggKind::Prod.identity(), Some(1.0));
    assert_eq!(FoldAggKind::Avg.identity(), None);
    assert_eq!(FoldAggKind::Collect.identity(), None);
}
2002
2003 async fn execute_fold_strict(
2006 input: Arc<dyn ExecutionPlan>,
2007 key_indices: Vec<usize>,
2008 fold_bindings: Vec<FoldBinding>,
2009 strict: bool,
2010 ) -> DFResult<RecordBatch> {
2011 let exec = FoldExec::new(input, key_indices, fold_bindings, strict, 1e-15);
2012 let ctx = SessionContext::new();
2013 let task_ctx = ctx.task_ctx();
2014 let stream = exec.execute(0, task_ctx).unwrap();
2015 let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream).await?;
2016 if batches.is_empty() {
2017 Ok(RecordBatch::new_empty(exec.schema()))
2018 } else {
2019 arrow::compute::concat_batches(&exec.schema(), &batches).map_err(arrow_err)
2020 }
2021 }
2022
2023 #[tokio::test]
2024 async fn test_nor_strict_rejects_above_one() {
2025 let batch = make_test_batch(vec!["a"], vec![1.5]);
2026 let input = make_memory_exec(batch);
2027 let result = execute_fold_strict(
2028 input,
2029 vec![0],
2030 vec![FoldBinding {
2031 output_name: "p".into(),
2032 kind: FoldAggKind::Nor,
2033 input_col_index: 1,
2034 input_col_name: None,
2035 }],
2036 true,
2037 )
2038 .await;
2039 assert!(result.is_err());
2040 let err = result.unwrap_err().to_string();
2041 assert!(
2042 err.contains("strict_probability_domain"),
2043 "Expected strict error, got: {}",
2044 err
2045 );
2046 }
2047
2048 #[tokio::test]
2049 async fn test_nor_strict_rejects_negative() {
2050 let batch = make_test_batch(vec!["a"], vec![-0.1]);
2051 let input = make_memory_exec(batch);
2052 let result = execute_fold_strict(
2053 input,
2054 vec![0],
2055 vec![FoldBinding {
2056 output_name: "p".into(),
2057 kind: FoldAggKind::Nor,
2058 input_col_index: 1,
2059 input_col_name: None,
2060 }],
2061 true,
2062 )
2063 .await;
2064 assert!(result.is_err());
2065 let err = result.unwrap_err().to_string();
2066 assert!(
2067 err.contains("strict_probability_domain"),
2068 "Expected strict error, got: {}",
2069 err
2070 );
2071 }
2072
2073 #[tokio::test]
2074 async fn test_prod_strict_rejects_above_one() {
2075 let batch = make_test_batch(vec!["a"], vec![2.0]);
2076 let input = make_memory_exec(batch);
2077 let result = execute_fold_strict(
2078 input,
2079 vec![0],
2080 vec![FoldBinding {
2081 output_name: "p".into(),
2082 kind: FoldAggKind::Prod,
2083 input_col_index: 1,
2084 input_col_name: None,
2085 }],
2086 true,
2087 )
2088 .await;
2089 assert!(result.is_err());
2090 let err = result.unwrap_err().to_string();
2091 assert!(
2092 err.contains("strict_probability_domain"),
2093 "Expected strict error, got: {}",
2094 err
2095 );
2096 }
2097
2098 #[tokio::test]
2099 async fn test_prod_strict_rejects_negative() {
2100 let batch = make_test_batch(vec!["a"], vec![-0.5]);
2101 let input = make_memory_exec(batch);
2102 let result = execute_fold_strict(
2103 input,
2104 vec![0],
2105 vec![FoldBinding {
2106 output_name: "p".into(),
2107 kind: FoldAggKind::Prod,
2108 input_col_index: 1,
2109 input_col_name: None,
2110 }],
2111 true,
2112 )
2113 .await;
2114 assert!(result.is_err());
2115 let err = result.unwrap_err().to_string();
2116 assert!(
2117 err.contains("strict_probability_domain"),
2118 "Expected strict error, got: {}",
2119 err
2120 );
2121 }
2122
2123 #[tokio::test]
2124 async fn test_nor_strict_accepts_valid() {
2125 let batch = make_test_batch(vec!["a", "a"], vec![0.3, 0.5]);
2126 let input = make_memory_exec(batch);
2127 let result = execute_fold_strict(
2128 input,
2129 vec![0],
2130 vec![FoldBinding {
2131 output_name: "p".into(),
2132 kind: FoldAggKind::Nor,
2133 input_col_index: 1,
2134 input_col_name: None,
2135 }],
2136 true,
2137 )
2138 .await;
2139 assert!(result.is_ok(), "Expected Ok, got: {:?}", result.err());
2140 let batch = result.unwrap();
2141 let vals = batch
2142 .column(1)
2143 .as_any()
2144 .downcast_ref::<Float64Array>()
2145 .unwrap();
2146 let expected = 0.65; assert!(
2148 (vals.value(0) - expected).abs() < 1e-10,
2149 "expected {}, got {}",
2150 expected,
2151 vals.value(0)
2152 );
2153 }
2154
2155 #[tokio::test]
2156 async fn test_count_all_groups_by_key() {
2157 let batch = make_test_batch(vec!["a", "a", "b"], vec![10.0, 20.0, 30.0]);
2159 let input = make_memory_exec(batch);
2160 let result = execute_fold(
2161 input,
2162 vec![0],
2163 vec![FoldBinding {
2164 output_name: "cnt".to_string(),
2165 kind: FoldAggKind::CountAll,
2166 input_col_index: 0, input_col_name: None,
2168 }],
2169 )
2170 .await;
2171
2172 assert_eq!(result.num_rows(), 2, "Should have 2 groups");
2173 let counts = result
2174 .column(1)
2175 .as_any()
2176 .downcast_ref::<Int64Array>()
2177 .unwrap();
2178 assert_eq!(counts.value(0), 2, "Group 'a' should have count 2");
2179 assert_eq!(counts.value(1), 1, "Group 'b' should have count 1");
2180 }
2181}