1use crate::query::df_graph::common::{
10 ScalarKey, arrow_err, compute_plan_properties, extract_scalar_key,
11};
12use arrow_array::builder::{Float64Builder, Int64Builder, LargeBinaryBuilder};
13use arrow_array::{Array, Float64Array, Int64Array, RecordBatch};
14use arrow_schema::{DataType, Field, Schema, SchemaRef};
15use datafusion::common::Result as DFResult;
16use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
17use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
18use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
19use futures::{Stream, TryStreamExt};
20use std::any::Any;
21use std::collections::HashMap;
22use std::fmt;
23use std::pin::Pin;
24use std::sync::Arc;
25use std::task::{Context, Poll};
26
/// Direction in which a monotonic [`FoldAggKind`] moves its accumulated value,
/// as reported by [`FoldAggKind::monotonicity_direction`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MonotonicDirection {
    /// The accumulated value never decreases.
    NonDecreasing,
    /// The accumulated value never increases.
    NonIncreasing,
}

/// Aggregate applied to every group by `FoldExec`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FoldAggKind {
    /// Sum of the non-null numeric values (Float64 output).
    Sum,
    /// Largest non-null value (output type matches the input column).
    Max,
    /// Smallest non-null value (output type matches the input column).
    Min,
    /// Number of non-null values (Int64 output).
    Count,
    /// Number of rows in the group, nulls included (Int64 output).
    CountAll,
    /// Mean of the non-null numeric values (Float64 output).
    Avg,
    /// Non-null values gathered into a list and encoded as LargeBinary.
    Collect,
    /// Noisy-OR of probabilities, `1 - Π(1 - p)` (Float64 output).
    Nor,
    /// Product of probabilities, `Π p` (Float64 output).
    Prod,
}

impl FoldAggKind {
    /// Returns `true` exactly when [`Self::monotonicity_direction`] reports a
    /// direction, i.e. for every kind except `Avg` and `Collect`.
    ///
    /// NOTE(review): `Sum` is only non-decreasing for non-negative inputs;
    /// callers presumably guarantee that — confirm.
    pub fn is_monotonic(&self) -> bool {
        self.monotonicity_direction().is_some()
    }

    /// Direction in which this aggregate's accumulated value moves, or `None`
    /// for aggregates without a monotone direction (`Avg`, `Collect`).
    pub fn monotonicity_direction(&self) -> Option<MonotonicDirection> {
        use MonotonicDirection::{NonDecreasing, NonIncreasing};
        match self {
            Self::Avg | Self::Collect => None,
            Self::Min | Self::Prod => Some(NonIncreasing),
            Self::Sum | Self::Max | Self::Count | Self::CountAll | Self::Nor => {
                Some(NonDecreasing)
            }
        }
    }

    /// Neutral starting value of the fold, or `None` for `Avg` and `Collect`,
    /// which have no scalar identity.
    pub fn identity(&self) -> Option<f64> {
        let neutral = match self {
            Self::Avg | Self::Collect => return None,
            Self::Prod => 1.0,
            Self::Min => f64::INFINITY,
            Self::Max => f64::NEG_INFINITY,
            Self::Sum | Self::Count | Self::CountAll | Self::Nor => 0.0,
        };
        Some(neutral)
    }
}
88
89#[derive(Debug, Clone)]
91pub struct FoldBinding {
92 pub output_name: String,
93 pub kind: FoldAggKind,
94 pub input_col_index: usize,
95}
96
97#[derive(Debug)]
102pub struct FoldExec {
103 input: Arc<dyn ExecutionPlan>,
104 key_indices: Vec<usize>,
105 fold_bindings: Vec<FoldBinding>,
106 strict_probability_domain: bool,
107 probability_epsilon: f64,
108 schema: SchemaRef,
109 properties: PlanProperties,
110 metrics: ExecutionPlanMetricsSet,
111}
112
113impl FoldExec {
114 pub fn new(
121 input: Arc<dyn ExecutionPlan>,
122 key_indices: Vec<usize>,
123 fold_bindings: Vec<FoldBinding>,
124 strict_probability_domain: bool,
125 probability_epsilon: f64,
126 ) -> Self {
127 let input_schema = input.schema();
128 let schema = Self::build_output_schema(&input_schema, &key_indices, &fold_bindings);
129 let properties = compute_plan_properties(Arc::clone(&schema));
130
131 Self {
132 input,
133 key_indices,
134 fold_bindings,
135 strict_probability_domain,
136 probability_epsilon,
137 schema,
138 properties,
139 metrics: ExecutionPlanMetricsSet::new(),
140 }
141 }
142
143 fn build_output_schema(
144 input_schema: &SchemaRef,
145 key_indices: &[usize],
146 fold_bindings: &[FoldBinding],
147 ) -> SchemaRef {
148 let mut fields = Vec::new();
149
150 for &ki in key_indices {
152 fields.push(Arc::new(input_schema.field(ki).clone()));
153 }
154
155 for binding in fold_bindings {
157 let output_type = match binding.kind {
158 FoldAggKind::Sum | FoldAggKind::Avg | FoldAggKind::Nor | FoldAggKind::Prod => {
159 DataType::Float64
160 }
161 FoldAggKind::Count | FoldAggKind::CountAll => DataType::Int64,
162 FoldAggKind::Max | FoldAggKind::Min => input_schema
163 .field(binding.input_col_index)
164 .data_type()
165 .clone(),
166 FoldAggKind::Collect => DataType::LargeBinary,
167 };
168 fields.push(Arc::new(Field::new(
169 &binding.output_name,
170 output_type,
171 true,
172 )));
173 }
174
175 Arc::new(Schema::new(fields))
176 }
177}
178
179impl DisplayAs for FoldExec {
180 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181 write!(
182 f,
183 "FoldExec: key_indices={:?}, bindings={:?}",
184 self.key_indices, self.fold_bindings
185 )
186 }
187}
188
189impl ExecutionPlan for FoldExec {
190 fn name(&self) -> &str {
191 "FoldExec"
192 }
193
194 fn as_any(&self) -> &dyn Any {
195 self
196 }
197
198 fn schema(&self) -> SchemaRef {
199 Arc::clone(&self.schema)
200 }
201
202 fn properties(&self) -> &PlanProperties {
203 &self.properties
204 }
205
206 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
207 vec![&self.input]
208 }
209
210 fn with_new_children(
211 self: Arc<Self>,
212 children: Vec<Arc<dyn ExecutionPlan>>,
213 ) -> DFResult<Arc<dyn ExecutionPlan>> {
214 if children.len() != 1 {
215 return Err(datafusion::error::DataFusionError::Plan(
216 "FoldExec requires exactly one child".to_string(),
217 ));
218 }
219 Ok(Arc::new(Self::new(
220 Arc::clone(&children[0]),
221 self.key_indices.clone(),
222 self.fold_bindings.clone(),
223 self.strict_probability_domain,
224 self.probability_epsilon,
225 )))
226 }
227
228 fn execute(
229 &self,
230 partition: usize,
231 context: Arc<TaskContext>,
232 ) -> DFResult<SendableRecordBatchStream> {
233 let input_stream = self.input.execute(partition, Arc::clone(&context))?;
234 let metrics = BaselineMetrics::new(&self.metrics, partition);
235 let key_indices = self.key_indices.clone();
236 let fold_bindings = self.fold_bindings.clone();
237 let strict = self.strict_probability_domain;
238 let epsilon = self.probability_epsilon;
239 let output_schema = Arc::clone(&self.schema);
240 let input_schema = self.input.schema();
241
242 let fut = async move {
243 let batches: Vec<RecordBatch> = input_stream.try_collect().await?;
244
245 if batches.is_empty() {
246 return Ok(RecordBatch::new_empty(output_schema));
247 }
248
249 let batch =
250 arrow::compute::concat_batches(&input_schema, &batches).map_err(arrow_err)?;
251
252 if batch.num_rows() == 0 {
253 return Ok(RecordBatch::new_empty(output_schema));
254 }
255
256 let mut groups: HashMap<Vec<ScalarKey>, Vec<usize>> = HashMap::new();
258 let mut ordered_keys: Vec<Vec<ScalarKey>> = Vec::new();
259 for row_idx in 0..batch.num_rows() {
260 let key = extract_scalar_key(&batch, &key_indices, row_idx);
261 let entry = groups.entry(key.clone());
262 if matches!(entry, std::collections::hash_map::Entry::Vacant(_)) {
263 ordered_keys.push(key);
264 }
265 entry.or_default().push(row_idx);
266 }
267
268 let num_groups = ordered_keys.len();
269
270 let mut output_columns: Vec<arrow_array::ArrayRef> = Vec::new();
272
273 for &ki in &key_indices {
275 let col = batch.column(ki);
276 let first_indices: Vec<u32> =
277 ordered_keys.iter().map(|k| groups[k][0] as u32).collect();
278 let idx_array = arrow_array::UInt32Array::from(first_indices);
279 let taken = arrow::compute::take(col.as_ref(), &idx_array, None).map_err(|e| {
280 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
281 })?;
282 output_columns.push(taken);
283 }
284
285 for binding in &fold_bindings {
287 let col: Arc<dyn Array> = if binding.kind == FoldAggKind::CountAll {
288 Arc::new(arrow_array::Int64Array::from(vec![0i64; batch.num_rows()]))
290 } else {
291 Arc::clone(batch.column(binding.input_col_index))
292 };
293 let agg_col = compute_fold_aggregate(
294 col.as_ref(),
295 &binding.kind,
296 &ordered_keys,
297 &groups,
298 num_groups,
299 strict,
300 epsilon,
301 )?;
302 output_columns.push(agg_col);
303 }
304
305 RecordBatch::try_new(output_schema, output_columns).map_err(arrow_err)
306 };
307
308 Ok(Box::pin(FoldStream {
309 state: FoldStreamState::Running(Box::pin(fut)),
310 schema: Arc::clone(&self.schema),
311 metrics,
312 }))
313 }
314
315 fn metrics(&self) -> Option<MetricsSet> {
316 Some(self.metrics.clone_inner())
317 }
318}
319
320fn compute_fold_aggregate(
325 col: &dyn Array,
326 kind: &FoldAggKind,
327 ordered_keys: &[Vec<ScalarKey>],
328 groups: &HashMap<Vec<ScalarKey>, Vec<usize>>,
329 num_groups: usize,
330 strict: bool,
331 probability_epsilon: f64,
332) -> DFResult<arrow_array::ArrayRef> {
333 match kind {
334 FoldAggKind::Sum => {
335 let mut builder = Float64Builder::with_capacity(num_groups);
336 for key in ordered_keys {
337 builder.append_option(sum_f64(col, &groups[key]));
338 }
339 Ok(Arc::new(builder.finish()))
340 }
341 FoldAggKind::Count => {
342 let mut builder = Int64Builder::with_capacity(num_groups);
343 for key in ordered_keys {
344 let indices = &groups[key];
345 let count = indices.iter().filter(|&&i| !col.is_null(i)).count();
346 builder.append_value(count as i64);
347 }
348 Ok(Arc::new(builder.finish()))
349 }
350 FoldAggKind::CountAll => {
351 let mut builder = Int64Builder::with_capacity(num_groups);
352 for key in ordered_keys {
353 let indices = &groups[key];
354 builder.append_value(indices.len() as i64);
355 }
356 Ok(Arc::new(builder.finish()))
357 }
358 FoldAggKind::Max => compute_minmax(col, ordered_keys, groups, num_groups, false),
359 FoldAggKind::Min => compute_minmax(col, ordered_keys, groups, num_groups, true),
360 FoldAggKind::Avg => {
361 let mut builder = Float64Builder::with_capacity(num_groups);
362 for key in ordered_keys {
363 let indices = &groups[key];
364 let count = indices.iter().filter(|&&i| !col.is_null(i)).count();
365 let avg = sum_f64(col, indices)
366 .filter(|_| count > 0)
367 .map(|s| s / count as f64);
368 builder.append_option(avg);
369 }
370 Ok(Arc::new(builder.finish()))
371 }
372 FoldAggKind::Collect => {
373 let mut builder = LargeBinaryBuilder::with_capacity(num_groups, num_groups * 32);
374 for key in ordered_keys {
375 let values: Vec<uni_common::Value> = groups[key]
376 .iter()
377 .filter(|&&i| !col.is_null(i))
378 .map(|&i| scalar_to_value(col, i))
379 .collect();
380 let encoded =
381 uni_common::cypher_value_codec::encode(&uni_common::Value::List(values));
382 builder.append_value(&encoded);
383 }
384 Ok(Arc::new(builder.finish()))
385 }
386 FoldAggKind::Nor => {
387 let mut builder = Float64Builder::with_capacity(num_groups);
388 for key in ordered_keys {
389 let indices = &groups[key];
390 builder.append_option(noisy_or_f64(col, indices, strict)?);
391 }
392 Ok(Arc::new(builder.finish()))
393 }
394 FoldAggKind::Prod => {
395 let mut builder = Float64Builder::with_capacity(num_groups);
396 for key in ordered_keys {
397 builder.append_option(product_f64(col, &groups[key], strict, probability_epsilon)?);
398 }
399 Ok(Arc::new(builder.finish()))
400 }
401 }
402}
403
404fn sum_f64(col: &dyn Array, indices: &[usize]) -> Option<f64> {
405 let mut sum = 0.0;
406 let mut has_value = false;
407 for &i in indices {
408 if col.is_null(i) {
409 continue;
410 }
411 has_value = true;
412 if let Some(arr) = col.as_any().downcast_ref::<Float64Array>() {
413 sum += arr.value(i);
414 } else if let Some(arr) = col.as_any().downcast_ref::<Int64Array>() {
415 sum += arr.value(i) as f64;
416 }
417 }
418 if has_value { Some(sum) } else { None }
419}
420
421fn noisy_or_f64(col: &dyn Array, indices: &[usize], strict: bool) -> DFResult<Option<f64>> {
423 let mut complement_product = 1.0;
424 let mut has_value = false;
425 for &i in indices {
426 if col.is_null(i) {
427 continue;
428 }
429 has_value = true;
430 let raw = if let Some(arr) = col.as_any().downcast_ref::<Float64Array>() {
431 arr.value(i)
432 } else if let Some(arr) = col.as_any().downcast_ref::<Int64Array>() {
433 arr.value(i) as f64
434 } else {
435 continue;
436 };
437 if strict && !(0.0..=1.0).contains(&raw) {
438 return Err(datafusion::error::DataFusionError::Execution(format!(
439 "strict_probability_domain: MNOR input {raw} is outside [0, 1]"
440 )));
441 }
442 if !strict && !(0.0..=1.0).contains(&raw) {
443 tracing::warn!(
444 "MNOR input {raw} outside [0,1], clamped to {}",
445 raw.clamp(0.0, 1.0)
446 );
447 }
448 let p = raw.clamp(0.0, 1.0);
449 complement_product *= 1.0 - p;
450 }
451 if has_value {
452 Ok(Some(1.0 - complement_product))
453 } else {
454 Ok(None)
455 }
456}
457
458fn product_f64(
463 col: &dyn Array,
464 indices: &[usize],
465 strict: bool,
466 probability_epsilon: f64,
467) -> DFResult<Option<f64>> {
468 let mut product = 1.0;
469 let mut log_sum = 0.0;
470 let mut use_log = false;
471 let mut has_value = false;
472 for &i in indices {
473 if col.is_null(i) {
474 continue;
475 }
476 has_value = true;
477 let raw = if let Some(arr) = col.as_any().downcast_ref::<Float64Array>() {
478 arr.value(i)
479 } else if let Some(arr) = col.as_any().downcast_ref::<Int64Array>() {
480 arr.value(i) as f64
481 } else {
482 continue;
483 };
484 if strict && !(0.0..=1.0).contains(&raw) {
485 return Err(datafusion::error::DataFusionError::Execution(format!(
486 "strict_probability_domain: MPROD input {raw} is outside [0, 1]"
487 )));
488 }
489 if !strict && !(0.0..=1.0).contains(&raw) {
490 tracing::warn!(
491 "MPROD input {raw} outside [0,1], clamped to {}",
492 raw.clamp(0.0, 1.0)
493 );
494 }
495 let p = raw.clamp(0.0, 1.0);
496 if p == 0.0 {
497 return Ok(Some(0.0));
498 }
499 if use_log {
500 log_sum += p.ln();
501 } else {
502 product *= p;
503 if product < probability_epsilon {
504 log_sum = product.ln();
506 use_log = true;
507 }
508 }
509 }
510 if !has_value {
511 return Ok(None);
512 }
513 if use_log {
514 Ok(Some(log_sum.exp()))
515 } else {
516 Ok(Some(product))
517 }
518}
519
520fn compute_minmax(
521 col: &dyn Array,
522 ordered_keys: &[Vec<ScalarKey>],
523 groups: &HashMap<Vec<ScalarKey>, Vec<usize>>,
524 num_groups: usize,
525 is_min: bool,
526) -> DFResult<arrow_array::ArrayRef> {
527 match col.data_type() {
528 DataType::Int64 => {
529 let arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
530 let mut builder = Int64Builder::with_capacity(num_groups);
531 for key in ordered_keys {
532 let mut result: Option<i64> = None;
533 for &i in &groups[key] {
534 if !arr.is_null(i) {
535 let v = arr.value(i);
536 result = Some(match result {
537 None => v,
538 Some(cur) if is_min => cur.min(v),
539 Some(cur) => cur.max(v),
540 });
541 }
542 }
543 builder.append_option(result);
544 }
545 Ok(Arc::new(builder.finish()))
546 }
547 DataType::Float64 => {
548 let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();
549 let mut builder = Float64Builder::with_capacity(num_groups);
550 for key in ordered_keys {
551 let mut result: Option<f64> = None;
552 for &i in &groups[key] {
553 if !arr.is_null(i) {
554 let v = arr.value(i);
555 result = Some(match result {
556 None => v,
557 Some(cur) if is_min => cur.min(v),
558 Some(cur) => cur.max(v),
559 });
560 }
561 }
562 builder.append_option(result);
563 }
564 Ok(Arc::new(builder.finish()))
565 }
566 dt => {
567 let use_large = matches!(dt, DataType::LargeUtf8);
571 let mut values: Vec<Option<String>> = Vec::with_capacity(num_groups);
572 for key in ordered_keys {
573 let indices = &groups[key];
574 let mut result: Option<String> = None;
575 for &i in indices {
576 if col.is_null(i) {
577 continue;
578 }
579 let v = format!("{:?}", scalar_to_value(col, i));
580 result = Some(match result {
581 None => v,
582 Some(cur) if is_min && v < cur => v,
583 Some(cur) if !is_min && v > cur => v,
584 Some(cur) => cur,
585 });
586 }
587 values.push(result);
588 }
589 Ok(build_optional_string_array(&values, use_large))
590 }
591 }
592}
593
594fn build_optional_string_array(
595 values: &[Option<String>],
596 use_large: bool,
597) -> arrow_array::ArrayRef {
598 if use_large {
599 let mut builder = arrow_array::builder::LargeStringBuilder::new();
600 for v in values {
601 match v {
602 Some(s) => builder.append_value(s),
603 None => builder.append_null(),
604 }
605 }
606 Arc::new(builder.finish())
607 } else {
608 let mut builder = arrow_array::builder::StringBuilder::new();
609 for v in values {
610 match v {
611 Some(s) => builder.append_value(s),
612 None => builder.append_null(),
613 }
614 }
615 Arc::new(builder.finish())
616 }
617}
618
619fn scalar_to_value(col: &dyn Array, row_idx: usize) -> uni_common::Value {
620 if col.is_null(row_idx) {
621 return uni_common::Value::Null;
622 }
623 match col.data_type() {
624 DataType::Int64 => {
625 let arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
626 uni_common::Value::Int(arr.value(row_idx))
627 }
628 DataType::Float64 => {
629 let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();
630 uni_common::Value::Float(arr.value(row_idx))
631 }
632 DataType::Utf8 => {
633 let arr = col
634 .as_any()
635 .downcast_ref::<arrow_array::StringArray>()
636 .unwrap();
637 uni_common::Value::String(arr.value(row_idx).to_string())
638 }
639 DataType::LargeUtf8 => {
640 let arr = col
641 .as_any()
642 .downcast_ref::<arrow_array::LargeStringArray>()
643 .unwrap();
644 uni_common::Value::String(arr.value(row_idx).to_string())
645 }
646 DataType::Boolean => {
647 let arr = col
648 .as_any()
649 .downcast_ref::<arrow_array::BooleanArray>()
650 .unwrap();
651 uni_common::Value::Bool(arr.value(row_idx))
652 }
653 DataType::LargeBinary => {
654 let arr = col
655 .as_any()
656 .downcast_ref::<arrow_array::LargeBinaryArray>()
657 .unwrap();
658 let bytes = arr.value(row_idx);
659 uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null)
660 }
661 _ => uni_common::Value::Null,
662 }
663}
664
665enum FoldStreamState {
670 Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
671 Done,
672}
673
674struct FoldStream {
675 state: FoldStreamState,
676 schema: SchemaRef,
677 metrics: BaselineMetrics,
678}
679
680impl Stream for FoldStream {
681 type Item = DFResult<RecordBatch>;
682
683 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
684 match &mut self.state {
685 FoldStreamState::Running(fut) => match fut.as_mut().poll(cx) {
686 Poll::Ready(Ok(batch)) => {
687 self.metrics.record_output(batch.num_rows());
688 self.state = FoldStreamState::Done;
689 Poll::Ready(Some(Ok(batch)))
690 }
691 Poll::Ready(Err(e)) => {
692 self.state = FoldStreamState::Done;
693 Poll::Ready(Some(Err(e)))
694 }
695 Poll::Pending => Poll::Pending,
696 },
697 FoldStreamState::Done => Poll::Ready(None),
698 }
699 }
700}
701
702impl RecordBatchStream for FoldStream {
703 fn schema(&self) -> SchemaRef {
704 Arc::clone(&self.schema)
705 }
706}
707
708#[cfg(test)]
709mod tests {
710 use super::*;
711 use arrow_array::{Float64Array, Int64Array, StringArray};
712 use arrow_schema::{DataType, Field, Schema};
713 use datafusion::physical_plan::memory::MemoryStream;
714 use datafusion::prelude::SessionContext;
715
716 fn make_test_batch(names: Vec<&str>, values: Vec<f64>) -> RecordBatch {
717 let schema = Arc::new(Schema::new(vec![
718 Field::new("name", DataType::Utf8, true),
719 Field::new("value", DataType::Float64, true),
720 ]));
721 RecordBatch::try_new(
722 schema,
723 vec![
724 Arc::new(StringArray::from(
725 names.into_iter().map(Some).collect::<Vec<_>>(),
726 )),
727 Arc::new(Float64Array::from(values)),
728 ],
729 )
730 .unwrap()
731 }
732
733 fn make_memory_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
734 let schema = batch.schema();
735 Arc::new(TestMemoryExec {
736 batches: vec![batch],
737 schema: schema.clone(),
738 properties: compute_plan_properties(schema),
739 })
740 }
741
742 #[derive(Debug)]
743 struct TestMemoryExec {
744 batches: Vec<RecordBatch>,
745 schema: SchemaRef,
746 properties: PlanProperties,
747 }
748
749 impl DisplayAs for TestMemoryExec {
750 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
751 write!(f, "TestMemoryExec")
752 }
753 }
754
755 impl ExecutionPlan for TestMemoryExec {
756 fn name(&self) -> &str {
757 "TestMemoryExec"
758 }
759 fn as_any(&self) -> &dyn Any {
760 self
761 }
762 fn schema(&self) -> SchemaRef {
763 Arc::clone(&self.schema)
764 }
765 fn properties(&self) -> &PlanProperties {
766 &self.properties
767 }
768 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
769 vec![]
770 }
771 fn with_new_children(
772 self: Arc<Self>,
773 _children: Vec<Arc<dyn ExecutionPlan>>,
774 ) -> DFResult<Arc<dyn ExecutionPlan>> {
775 Ok(self)
776 }
777 fn execute(
778 &self,
779 _partition: usize,
780 _context: Arc<TaskContext>,
781 ) -> DFResult<SendableRecordBatchStream> {
782 Ok(Box::pin(MemoryStream::try_new(
783 self.batches.clone(),
784 Arc::clone(&self.schema),
785 None,
786 )?))
787 }
788 }
789
790 async fn execute_fold(
791 input: Arc<dyn ExecutionPlan>,
792 key_indices: Vec<usize>,
793 fold_bindings: Vec<FoldBinding>,
794 ) -> RecordBatch {
795 let exec = FoldExec::new(input, key_indices, fold_bindings, false, 1e-15);
796 let ctx = SessionContext::new();
797 let task_ctx = ctx.task_ctx();
798 let stream = exec.execute(0, task_ctx).unwrap();
799 let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream)
800 .await
801 .unwrap();
802 if batches.is_empty() {
803 RecordBatch::new_empty(exec.schema())
804 } else {
805 arrow::compute::concat_batches(&exec.schema(), &batches).unwrap()
806 }
807 }
808
809 #[tokio::test]
810 async fn test_sum_single_group() {
811 let batch = make_test_batch(vec!["a", "a", "a"], vec![1.0, 2.0, 3.0]);
812 let input = make_memory_exec(batch);
813 let result = execute_fold(
814 input,
815 vec![0],
816 vec![FoldBinding {
817 output_name: "total".to_string(),
818 kind: FoldAggKind::Sum,
819 input_col_index: 1,
820 }],
821 )
822 .await;
823
824 assert_eq!(result.num_rows(), 1);
825 let totals = result
826 .column(1)
827 .as_any()
828 .downcast_ref::<Float64Array>()
829 .unwrap();
830 assert!((totals.value(0) - 6.0).abs() < f64::EPSILON);
831 }
832
833 #[tokio::test]
834 async fn test_count_non_null() {
835 let schema = Arc::new(Schema::new(vec![
836 Field::new("name", DataType::Utf8, true),
837 Field::new("value", DataType::Float64, true),
838 ]));
839 let batch = RecordBatch::try_new(
840 schema,
841 vec![
842 Arc::new(StringArray::from(vec![Some("a"), Some("a"), Some("a")])),
843 Arc::new(Float64Array::from(vec![Some(1.0), None, Some(3.0)])),
844 ],
845 )
846 .unwrap();
847 let input = make_memory_exec(batch);
848 let result = execute_fold(
849 input,
850 vec![0],
851 vec![FoldBinding {
852 output_name: "cnt".to_string(),
853 kind: FoldAggKind::Count,
854 input_col_index: 1,
855 }],
856 )
857 .await;
858
859 assert_eq!(result.num_rows(), 1);
860 let counts = result
861 .column(1)
862 .as_any()
863 .downcast_ref::<Int64Array>()
864 .unwrap();
865 assert_eq!(counts.value(0), 2); }
867
868 #[tokio::test]
869 async fn test_max_min() {
870 let batch = make_test_batch(vec!["a", "a", "a"], vec![3.0, 1.0, 5.0]);
871 let input_max = make_memory_exec(batch.clone());
872 let input_min = make_memory_exec(batch);
873
874 let result_max = execute_fold(
875 input_max,
876 vec![0],
877 vec![FoldBinding {
878 output_name: "mx".to_string(),
879 kind: FoldAggKind::Max,
880 input_col_index: 1,
881 }],
882 )
883 .await;
884 let result_min = execute_fold(
885 input_min,
886 vec![0],
887 vec![FoldBinding {
888 output_name: "mn".to_string(),
889 kind: FoldAggKind::Min,
890 input_col_index: 1,
891 }],
892 )
893 .await;
894
895 let max_vals = result_max
896 .column(1)
897 .as_any()
898 .downcast_ref::<Float64Array>()
899 .unwrap();
900 assert_eq!(max_vals.value(0), 5.0);
901
902 let min_vals = result_min
903 .column(1)
904 .as_any()
905 .downcast_ref::<Float64Array>()
906 .unwrap();
907 assert_eq!(min_vals.value(0), 1.0);
908 }
909
910 #[tokio::test]
911 async fn test_avg() {
912 let batch = make_test_batch(vec!["a", "a", "a", "a"], vec![2.0, 4.0, 6.0, 8.0]);
913 let input = make_memory_exec(batch);
914 let result = execute_fold(
915 input,
916 vec![0],
917 vec![FoldBinding {
918 output_name: "average".to_string(),
919 kind: FoldAggKind::Avg,
920 input_col_index: 1,
921 }],
922 )
923 .await;
924
925 assert_eq!(result.num_rows(), 1);
926 let avgs = result
927 .column(1)
928 .as_any()
929 .downcast_ref::<Float64Array>()
930 .unwrap();
931 assert!((avgs.value(0) - 5.0).abs() < f64::EPSILON);
932 }
933
934 #[tokio::test]
935 async fn test_multiple_groups() {
936 let batch = make_test_batch(
937 vec!["a", "a", "b", "b", "b"],
938 vec![1.0, 2.0, 10.0, 20.0, 30.0],
939 );
940 let input = make_memory_exec(batch);
941 let result = execute_fold(
942 input,
943 vec![0],
944 vec![FoldBinding {
945 output_name: "total".to_string(),
946 kind: FoldAggKind::Sum,
947 input_col_index: 1,
948 }],
949 )
950 .await;
951
952 assert_eq!(result.num_rows(), 2);
953 let names = result
954 .column(0)
955 .as_any()
956 .downcast_ref::<StringArray>()
957 .unwrap();
958 let totals = result
959 .column(1)
960 .as_any()
961 .downcast_ref::<Float64Array>()
962 .unwrap();
963
964 for i in 0..2 {
965 match names.value(i) {
966 "a" => assert!((totals.value(i) - 3.0).abs() < f64::EPSILON),
967 "b" => assert!((totals.value(i) - 60.0).abs() < f64::EPSILON),
968 _ => panic!("unexpected name"),
969 }
970 }
971 }
972
973 #[tokio::test]
974 async fn test_empty_input() {
975 let schema = Arc::new(Schema::new(vec![
976 Field::new("name", DataType::Utf8, true),
977 Field::new("value", DataType::Float64, true),
978 ]));
979 let batch = RecordBatch::new_empty(schema);
980 let input = make_memory_exec(batch);
981 let result = execute_fold(
982 input,
983 vec![0],
984 vec![FoldBinding {
985 output_name: "total".to_string(),
986 kind: FoldAggKind::Sum,
987 input_col_index: 1,
988 }],
989 )
990 .await;
991
992 assert_eq!(result.num_rows(), 0);
993 }
994
995 #[tokio::test]
996 async fn test_multiple_bindings() {
997 let batch = make_test_batch(vec!["a", "a", "a"], vec![1.0, 2.0, 3.0]);
998 let input = make_memory_exec(batch);
999 let result = execute_fold(
1000 input,
1001 vec![0],
1002 vec![
1003 FoldBinding {
1004 output_name: "total".to_string(),
1005 kind: FoldAggKind::Sum,
1006 input_col_index: 1,
1007 },
1008 FoldBinding {
1009 output_name: "cnt".to_string(),
1010 kind: FoldAggKind::Count,
1011 input_col_index: 1,
1012 },
1013 FoldBinding {
1014 output_name: "mx".to_string(),
1015 kind: FoldAggKind::Max,
1016 input_col_index: 1,
1017 },
1018 ],
1019 )
1020 .await;
1021
1022 assert_eq!(result.num_rows(), 1);
1023 assert_eq!(result.num_columns(), 4); let totals = result
1026 .column(1)
1027 .as_any()
1028 .downcast_ref::<Float64Array>()
1029 .unwrap();
1030 assert!((totals.value(0) - 6.0).abs() < f64::EPSILON);
1031
1032 let counts = result
1033 .column(2)
1034 .as_any()
1035 .downcast_ref::<Int64Array>()
1036 .unwrap();
1037 assert_eq!(counts.value(0), 3);
1038
1039 let maxes = result
1040 .column(3)
1041 .as_any()
1042 .downcast_ref::<Float64Array>()
1043 .unwrap();
1044 assert_eq!(maxes.value(0), 3.0);
1045 }
1046
1047 #[tokio::test]
1050 async fn test_nor_single_group() {
1051 let batch = make_test_batch(vec!["a", "a"], vec![0.3, 0.5]);
1053 let input = make_memory_exec(batch);
1054 let result = execute_fold(
1055 input,
1056 vec![0],
1057 vec![FoldBinding {
1058 output_name: "prob".to_string(),
1059 kind: FoldAggKind::Nor,
1060 input_col_index: 1,
1061 }],
1062 )
1063 .await;
1064
1065 assert_eq!(result.num_rows(), 1);
1066 let vals = result
1067 .column(1)
1068 .as_any()
1069 .downcast_ref::<Float64Array>()
1070 .unwrap();
1071 assert!((vals.value(0) - 0.65).abs() < 1e-10);
1072 }
1073
1074 #[tokio::test]
1075 async fn test_nor_identity() {
1076 let batch = make_test_batch(vec!["a", "a"], vec![0.0, 0.0]);
1078 let input = make_memory_exec(batch);
1079 let result = execute_fold(
1080 input,
1081 vec![0],
1082 vec![FoldBinding {
1083 output_name: "prob".to_string(),
1084 kind: FoldAggKind::Nor,
1085 input_col_index: 1,
1086 }],
1087 )
1088 .await;
1089
1090 let vals = result
1091 .column(1)
1092 .as_any()
1093 .downcast_ref::<Float64Array>()
1094 .unwrap();
1095 assert!((vals.value(0) - 0.0).abs() < 1e-10);
1096 }
1097
1098 #[tokio::test]
1099 async fn test_nor_clamping() {
1100 let batch = make_test_batch(vec!["a", "a"], vec![-0.5, 1.5]);
1102 let input = make_memory_exec(batch);
1103 let result = execute_fold(
1104 input,
1105 vec![0],
1106 vec![FoldBinding {
1107 output_name: "prob".to_string(),
1108 kind: FoldAggKind::Nor,
1109 input_col_index: 1,
1110 }],
1111 )
1112 .await;
1113
1114 let vals = result
1115 .column(1)
1116 .as_any()
1117 .downcast_ref::<Float64Array>()
1118 .unwrap();
1119 assert!((vals.value(0) - 1.0).abs() < 1e-10);
1121 }
1122
1123 #[tokio::test]
1124 async fn test_nor_multiple_groups() {
1125 let batch = make_test_batch(vec!["a", "a", "b", "b"], vec![0.3, 0.5, 0.1, 0.2]);
1126 let input = make_memory_exec(batch);
1127 let result = execute_fold(
1128 input,
1129 vec![0],
1130 vec![FoldBinding {
1131 output_name: "prob".to_string(),
1132 kind: FoldAggKind::Nor,
1133 input_col_index: 1,
1134 }],
1135 )
1136 .await;
1137
1138 assert_eq!(result.num_rows(), 2);
1139 let names = result
1140 .column(0)
1141 .as_any()
1142 .downcast_ref::<StringArray>()
1143 .unwrap();
1144 let vals = result
1145 .column(1)
1146 .as_any()
1147 .downcast_ref::<Float64Array>()
1148 .unwrap();
1149
1150 for i in 0..2 {
1151 match names.value(i) {
1152 "a" => assert!((vals.value(i) - 0.65).abs() < 1e-10),
1154 "b" => assert!((vals.value(i) - 0.28).abs() < 1e-10),
1156 _ => panic!("unexpected name"),
1157 }
1158 }
1159 }
1160
1161 #[tokio::test]
1164 async fn test_prod_single_group() {
1165 let batch = make_test_batch(vec!["a", "a"], vec![0.6, 0.8]);
1167 let input = make_memory_exec(batch);
1168 let result = execute_fold(
1169 input,
1170 vec![0],
1171 vec![FoldBinding {
1172 output_name: "prob".to_string(),
1173 kind: FoldAggKind::Prod,
1174 input_col_index: 1,
1175 }],
1176 )
1177 .await;
1178
1179 assert_eq!(result.num_rows(), 1);
1180 let vals = result
1181 .column(1)
1182 .as_any()
1183 .downcast_ref::<Float64Array>()
1184 .unwrap();
1185 assert!((vals.value(0) - 0.48).abs() < 1e-10);
1186 }
1187
1188 #[tokio::test]
1189 async fn test_prod_identity() {
1190 let batch = make_test_batch(vec!["a", "a"], vec![1.0, 1.0]);
1192 let input = make_memory_exec(batch);
1193 let result = execute_fold(
1194 input,
1195 vec![0],
1196 vec![FoldBinding {
1197 output_name: "prob".to_string(),
1198 kind: FoldAggKind::Prod,
1199 input_col_index: 1,
1200 }],
1201 )
1202 .await;
1203
1204 let vals = result
1205 .column(1)
1206 .as_any()
1207 .downcast_ref::<Float64Array>()
1208 .unwrap();
1209 assert!((vals.value(0) - 1.0).abs() < 1e-10);
1210 }
1211
1212 #[tokio::test]
1213 async fn test_prod_zero_absorbing() {
1214 let batch = make_test_batch(vec!["a", "a", "a"], vec![0.5, 0.0, 0.8]);
1216 let input = make_memory_exec(batch);
1217 let result = execute_fold(
1218 input,
1219 vec![0],
1220 vec![FoldBinding {
1221 output_name: "prob".to_string(),
1222 kind: FoldAggKind::Prod,
1223 input_col_index: 1,
1224 }],
1225 )
1226 .await;
1227
1228 let vals = result
1229 .column(1)
1230 .as_any()
1231 .downcast_ref::<Float64Array>()
1232 .unwrap();
1233 assert!((vals.value(0) - 0.0).abs() < 1e-10);
1234 }
1235
1236 #[tokio::test]
1237 async fn test_prod_underflow_protection() {
1238 let names: Vec<&str> = vec!["a"; 50];
1240 let values: Vec<f64> = vec![0.5; 50];
1241 let batch = make_test_batch(names, values);
1242 let input = make_memory_exec(batch);
1243 let result = execute_fold(
1244 input,
1245 vec![0],
1246 vec![FoldBinding {
1247 output_name: "prob".to_string(),
1248 kind: FoldAggKind::Prod,
1249 input_col_index: 1,
1250 }],
1251 )
1252 .await;
1253
1254 let vals = result
1255 .column(1)
1256 .as_any()
1257 .downcast_ref::<Float64Array>()
1258 .unwrap();
1259 let expected = 0.5_f64.powi(50); assert!(vals.value(0) > 0.0, "should not underflow to zero");
1261 assert!(
1262 (vals.value(0) - expected).abs() / expected < 1e-6,
1263 "result {} should be close to expected {}",
1264 vals.value(0),
1265 expected
1266 );
1267 }
1268
1269 fn make_nullable_test_batch(names: Vec<&str>, values: Vec<Option<f64>>) -> RecordBatch {
1272 let schema = Arc::new(Schema::new(vec![
1273 Field::new("name", DataType::Utf8, true),
1274 Field::new("value", DataType::Float64, true),
1275 ]));
1276 RecordBatch::try_new(
1277 schema,
1278 vec![
1279 Arc::new(StringArray::from(
1280 names.into_iter().map(Some).collect::<Vec<_>>(),
1281 )),
1282 Arc::new(Float64Array::from(values)),
1283 ],
1284 )
1285 .unwrap()
1286 }
1287
1288 #[tokio::test]
1289 async fn test_nor_single_element() {
1290 let batch = make_test_batch(vec!["a"], vec![0.7]);
1292 let input = make_memory_exec(batch);
1293 let result = execute_fold(
1294 input,
1295 vec![0],
1296 vec![FoldBinding {
1297 output_name: "prob".to_string(),
1298 kind: FoldAggKind::Nor,
1299 input_col_index: 1,
1300 }],
1301 )
1302 .await;
1303 let vals = result
1304 .column(1)
1305 .as_any()
1306 .downcast_ref::<Float64Array>()
1307 .unwrap();
1308 assert!((vals.value(0) - 0.7).abs() < 1e-10);
1309 }
1310
1311 #[tokio::test]
1312 async fn test_prod_single_element() {
1313 let batch = make_test_batch(vec!["a"], vec![0.7]);
1315 let input = make_memory_exec(batch);
1316 let result = execute_fold(
1317 input,
1318 vec![0],
1319 vec![FoldBinding {
1320 output_name: "prob".to_string(),
1321 kind: FoldAggKind::Prod,
1322 input_col_index: 1,
1323 }],
1324 )
1325 .await;
1326 let vals = result
1327 .column(1)
1328 .as_any()
1329 .downcast_ref::<Float64Array>()
1330 .unwrap();
1331 assert!((vals.value(0) - 0.7).abs() < 1e-10);
1332 }
1333
1334 #[tokio::test]
1335 async fn test_nor_three_elements() {
1336 let batch = make_test_batch(vec!["a", "a", "a"], vec![0.3, 0.4, 0.5]);
1338 let input = make_memory_exec(batch);
1339 let result = execute_fold(
1340 input,
1341 vec![0],
1342 vec![FoldBinding {
1343 output_name: "prob".to_string(),
1344 kind: FoldAggKind::Nor,
1345 input_col_index: 1,
1346 }],
1347 )
1348 .await;
1349 let vals = result
1350 .column(1)
1351 .as_any()
1352 .downcast_ref::<Float64Array>()
1353 .unwrap();
1354 assert!((vals.value(0) - 0.79).abs() < 1e-10);
1355 }
1356
1357 #[tokio::test]
1358 async fn test_nor_four_elements_spec_example() {
1359 let batch = make_test_batch(vec!["a", "a", "a", "a"], vec![0.72, 0.54, 0.56, 0.42]);
1361 let input = make_memory_exec(batch);
1362 let result = execute_fold(
1363 input,
1364 vec![0],
1365 vec![FoldBinding {
1366 output_name: "prob".to_string(),
1367 kind: FoldAggKind::Nor,
1368 input_col_index: 1,
1369 }],
1370 )
1371 .await;
1372 let vals = result
1373 .column(1)
1374 .as_any()
1375 .downcast_ref::<Float64Array>()
1376 .unwrap();
1377 assert!(
1378 (vals.value(0) - 0.96713024).abs() < 1e-10,
1379 "expected 0.96713024, got {}",
1380 vals.value(0)
1381 );
1382 }
1383
1384 #[tokio::test]
1385 async fn test_prod_three_elements() {
1386 let batch = make_test_batch(vec!["a", "a", "a"], vec![0.5, 0.5, 0.5]);
1388 let input = make_memory_exec(batch);
1389 let result = execute_fold(
1390 input,
1391 vec![0],
1392 vec![FoldBinding {
1393 output_name: "prob".to_string(),
1394 kind: FoldAggKind::Prod,
1395 input_col_index: 1,
1396 }],
1397 )
1398 .await;
1399 let vals = result
1400 .column(1)
1401 .as_any()
1402 .downcast_ref::<Float64Array>()
1403 .unwrap();
1404 assert!((vals.value(0) - 0.125).abs() < 1e-10);
1405 }
1406
1407 #[tokio::test]
1408 async fn test_nor_absorbing_element() {
1409 let batch = make_test_batch(vec!["a", "a"], vec![0.3, 1.0]);
1411 let input = make_memory_exec(batch);
1412 let result = execute_fold(
1413 input,
1414 vec![0],
1415 vec![FoldBinding {
1416 output_name: "prob".to_string(),
1417 kind: FoldAggKind::Nor,
1418 input_col_index: 1,
1419 }],
1420 )
1421 .await;
1422 let vals = result
1423 .column(1)
1424 .as_any()
1425 .downcast_ref::<Float64Array>()
1426 .unwrap();
1427 assert!((vals.value(0) - 1.0).abs() < 1e-10);
1428 }
1429
1430 #[tokio::test]
1431 async fn test_prod_clamping() {
1432 let batch = make_test_batch(vec!["a", "a"], vec![2.0, 0.5]);
1434 let input = make_memory_exec(batch);
1435 let result = execute_fold(
1436 input,
1437 vec![0],
1438 vec![FoldBinding {
1439 output_name: "prob".to_string(),
1440 kind: FoldAggKind::Prod,
1441 input_col_index: 1,
1442 }],
1443 )
1444 .await;
1445 let vals = result
1446 .column(1)
1447 .as_any()
1448 .downcast_ref::<Float64Array>()
1449 .unwrap();
1450 assert!((vals.value(0) - 0.5).abs() < 1e-10);
1451 }
1452
1453 #[tokio::test]
1454 async fn test_prod_multiple_groups() {
1455 let batch = make_test_batch(vec!["a", "a", "b", "b"], vec![0.6, 0.8, 0.5, 0.5]);
1457 let input = make_memory_exec(batch);
1458 let result = execute_fold(
1459 input,
1460 vec![0],
1461 vec![FoldBinding {
1462 output_name: "prob".to_string(),
1463 kind: FoldAggKind::Prod,
1464 input_col_index: 1,
1465 }],
1466 )
1467 .await;
1468
1469 assert_eq!(result.num_rows(), 2);
1470 let names = result
1471 .column(0)
1472 .as_any()
1473 .downcast_ref::<StringArray>()
1474 .unwrap();
1475 let vals = result
1476 .column(1)
1477 .as_any()
1478 .downcast_ref::<Float64Array>()
1479 .unwrap();
1480 for i in 0..2 {
1481 match names.value(i) {
1482 "a" => assert!((vals.value(i) - 0.48).abs() < 1e-10),
1483 "b" => assert!((vals.value(i) - 0.25).abs() < 1e-10),
1484 _ => panic!("unexpected group name"),
1485 }
1486 }
1487 }
1488
1489 #[tokio::test]
1490 async fn test_nor_commutativity() {
1491 let fwd = make_test_batch(vec!["a", "a", "a"], vec![0.2, 0.5, 0.8]);
1493 let rev = make_test_batch(vec!["a", "a", "a"], vec![0.8, 0.5, 0.2]);
1494 let binding = vec![FoldBinding {
1495 output_name: "prob".to_string(),
1496 kind: FoldAggKind::Nor,
1497 input_col_index: 1,
1498 }];
1499 let r1 = execute_fold(make_memory_exec(fwd), vec![0], binding.clone()).await;
1500 let r2 = execute_fold(make_memory_exec(rev), vec![0], binding).await;
1501 let v1 = r1
1502 .column(1)
1503 .as_any()
1504 .downcast_ref::<Float64Array>()
1505 .unwrap()
1506 .value(0);
1507 let v2 = r2
1508 .column(1)
1509 .as_any()
1510 .downcast_ref::<Float64Array>()
1511 .unwrap()
1512 .value(0);
1513 assert!((v1 - 0.92).abs() < 1e-10);
1514 assert!((v2 - 0.92).abs() < 1e-10);
1515 assert!((v1 - v2).abs() < 1e-15, "commutativity violated");
1516 }
1517
1518 #[tokio::test]
1519 async fn test_prod_commutativity() {
1520 let fwd = make_test_batch(vec!["a", "a"], vec![0.5, 0.25]);
1522 let rev = make_test_batch(vec!["a", "a"], vec![0.25, 0.5]);
1523 let binding = vec![FoldBinding {
1524 output_name: "prob".to_string(),
1525 kind: FoldAggKind::Prod,
1526 input_col_index: 1,
1527 }];
1528 let r1 = execute_fold(make_memory_exec(fwd), vec![0], binding.clone()).await;
1529 let r2 = execute_fold(make_memory_exec(rev), vec![0], binding).await;
1530 let v1 = r1
1531 .column(1)
1532 .as_any()
1533 .downcast_ref::<Float64Array>()
1534 .unwrap()
1535 .value(0);
1536 let v2 = r2
1537 .column(1)
1538 .as_any()
1539 .downcast_ref::<Float64Array>()
1540 .unwrap()
1541 .value(0);
1542 assert!((v1 - 0.125).abs() < 1e-10);
1543 assert!((v2 - 0.125).abs() < 1e-10);
1544 assert!((v1 - v2).abs() < 1e-15, "commutativity violated");
1545 }
1546
1547 #[tokio::test]
1548 async fn test_nor_boundary_near_zero() {
1549 let batch = make_test_batch(vec!["a", "a"], vec![0.001, 0.002]);
1551 let input = make_memory_exec(batch);
1552 let result = execute_fold(
1553 input,
1554 vec![0],
1555 vec![FoldBinding {
1556 output_name: "prob".to_string(),
1557 kind: FoldAggKind::Nor,
1558 input_col_index: 1,
1559 }],
1560 )
1561 .await;
1562 let vals = result
1563 .column(1)
1564 .as_any()
1565 .downcast_ref::<Float64Array>()
1566 .unwrap();
1567 let expected = 1.0 - 0.999 * 0.998;
1568 assert!(
1569 (vals.value(0) - expected).abs() < 1e-10,
1570 "expected {}, got {}",
1571 expected,
1572 vals.value(0)
1573 );
1574 }
1575
1576 #[tokio::test]
1577 async fn test_nor_boundary_near_one() {
1578 let batch = make_test_batch(vec!["a", "a"], vec![0.999, 0.998]);
1580 let input = make_memory_exec(batch);
1581 let result = execute_fold(
1582 input,
1583 vec![0],
1584 vec![FoldBinding {
1585 output_name: "prob".to_string(),
1586 kind: FoldAggKind::Nor,
1587 input_col_index: 1,
1588 }],
1589 )
1590 .await;
1591 let vals = result
1592 .column(1)
1593 .as_any()
1594 .downcast_ref::<Float64Array>()
1595 .unwrap();
1596 let expected = 1.0 - 0.001 * 0.002;
1597 assert!(
1598 (vals.value(0) - expected).abs() < 1e-10,
1599 "expected {}, got {}",
1600 expected,
1601 vals.value(0)
1602 );
1603 }
1604
1605 #[tokio::test]
1606 async fn test_prod_boundary_near_zero() {
1607 let batch = make_test_batch(vec!["a", "a"], vec![0.001, 0.002]);
1609 let input = make_memory_exec(batch);
1610 let result = execute_fold(
1611 input,
1612 vec![0],
1613 vec![FoldBinding {
1614 output_name: "prob".to_string(),
1615 kind: FoldAggKind::Prod,
1616 input_col_index: 1,
1617 }],
1618 )
1619 .await;
1620 let vals = result
1621 .column(1)
1622 .as_any()
1623 .downcast_ref::<Float64Array>()
1624 .unwrap();
1625 assert!(
1626 (vals.value(0) - 2e-6).abs() < 1e-15,
1627 "expected 2e-6, got {}",
1628 vals.value(0)
1629 );
1630 }
1631
1632 #[tokio::test]
1633 async fn test_nor_empty_input() {
1634 let schema = Arc::new(Schema::new(vec![
1636 Field::new("name", DataType::Utf8, true),
1637 Field::new("value", DataType::Float64, true),
1638 ]));
1639 let batch = RecordBatch::new_empty(schema);
1640 let input = make_memory_exec(batch);
1641 let result = execute_fold(
1642 input,
1643 vec![0],
1644 vec![FoldBinding {
1645 output_name: "prob".to_string(),
1646 kind: FoldAggKind::Nor,
1647 input_col_index: 1,
1648 }],
1649 )
1650 .await;
1651 assert_eq!(result.num_rows(), 0);
1652 }
1653
1654 #[tokio::test]
1655 async fn test_nor_nan_handling() {
1656 let batch = make_test_batch(vec!["a", "a"], vec![0.3, f64::NAN]);
1658 let input = make_memory_exec(batch);
1659 let result = execute_fold(
1660 input,
1661 vec![0],
1662 vec![FoldBinding {
1663 output_name: "prob".to_string(),
1664 kind: FoldAggKind::Nor,
1665 input_col_index: 1,
1666 }],
1667 )
1668 .await;
1669 let vals = result
1670 .column(1)
1671 .as_any()
1672 .downcast_ref::<Float64Array>()
1673 .unwrap();
1674 assert!(vals.value(0).is_nan(), "NaN should propagate through MNOR");
1675 }
1676
1677 #[tokio::test]
1678 async fn test_prod_nan_handling() {
1679 let batch = make_test_batch(vec!["a", "a"], vec![0.5, f64::NAN]);
1681 let input = make_memory_exec(batch);
1682 let result = execute_fold(
1683 input,
1684 vec![0],
1685 vec![FoldBinding {
1686 output_name: "prob".to_string(),
1687 kind: FoldAggKind::Prod,
1688 input_col_index: 1,
1689 }],
1690 )
1691 .await;
1692 let vals = result
1693 .column(1)
1694 .as_any()
1695 .downcast_ref::<Float64Array>()
1696 .unwrap();
1697 assert!(vals.value(0).is_nan(), "NaN should propagate through MPROD");
1698 }
1699
1700 #[tokio::test]
1701 async fn test_prod_infinity_handling() {
1702 let batch = make_test_batch(vec!["a", "a"], vec![0.5, f64::INFINITY]);
1704 let input = make_memory_exec(batch);
1705 let result = execute_fold(
1706 input,
1707 vec![0],
1708 vec![FoldBinding {
1709 output_name: "prob".to_string(),
1710 kind: FoldAggKind::Prod,
1711 input_col_index: 1,
1712 }],
1713 )
1714 .await;
1715 let vals = result
1716 .column(1)
1717 .as_any()
1718 .downcast_ref::<Float64Array>()
1719 .unwrap();
1720 assert!((vals.value(0) - 0.5).abs() < 1e-10);
1721 }
1722
1723 #[tokio::test]
1724 async fn test_nor_infinity_handling() {
1725 let batch = make_test_batch(vec!["a", "a"], vec![0.3, f64::INFINITY]);
1727 let input = make_memory_exec(batch);
1728 let result = execute_fold(
1729 input,
1730 vec![0],
1731 vec![FoldBinding {
1732 output_name: "prob".to_string(),
1733 kind: FoldAggKind::Nor,
1734 input_col_index: 1,
1735 }],
1736 )
1737 .await;
1738 let vals = result
1739 .column(1)
1740 .as_any()
1741 .downcast_ref::<Float64Array>()
1742 .unwrap();
1743 assert!((vals.value(0) - 1.0).abs() < 1e-10);
1744 }
1745
1746 #[tokio::test]
1747 async fn test_nor_all_null_values() {
1748 let batch = make_nullable_test_batch(vec!["a", "a"], vec![None, None]);
1750 let input = make_memory_exec(batch);
1751 let result = execute_fold(
1752 input,
1753 vec![0],
1754 vec![FoldBinding {
1755 output_name: "prob".to_string(),
1756 kind: FoldAggKind::Nor,
1757 input_col_index: 1,
1758 }],
1759 )
1760 .await;
1761 assert_eq!(result.num_rows(), 1);
1762 let vals = result
1763 .column(1)
1764 .as_any()
1765 .downcast_ref::<Float64Array>()
1766 .unwrap();
1767 assert!(vals.is_null(0), "all-null MNOR should produce null");
1768 }
1769
1770 #[tokio::test]
1771 async fn test_prod_all_null_values() {
1772 let batch = make_nullable_test_batch(vec!["a", "a"], vec![None, None]);
1774 let input = make_memory_exec(batch);
1775 let result = execute_fold(
1776 input,
1777 vec![0],
1778 vec![FoldBinding {
1779 output_name: "prob".to_string(),
1780 kind: FoldAggKind::Prod,
1781 input_col_index: 1,
1782 }],
1783 )
1784 .await;
1785 assert_eq!(result.num_rows(), 1);
1786 let vals = result
1787 .column(1)
1788 .as_any()
1789 .downcast_ref::<Float64Array>()
1790 .unwrap();
1791 assert!(vals.is_null(0), "all-null MPROD should produce null");
1792 }
1793
1794 #[tokio::test]
1795 async fn test_nor_mixed_null_values() {
1796 let batch = make_nullable_test_batch(vec!["a", "a", "a"], vec![Some(0.3), None, Some(0.5)]);
1798 let input = make_memory_exec(batch);
1799 let result = execute_fold(
1800 input,
1801 vec![0],
1802 vec![FoldBinding {
1803 output_name: "prob".to_string(),
1804 kind: FoldAggKind::Nor,
1805 input_col_index: 1,
1806 }],
1807 )
1808 .await;
1809 let vals = result
1810 .column(1)
1811 .as_any()
1812 .downcast_ref::<Float64Array>()
1813 .unwrap();
1814 assert!((vals.value(0) - 0.65).abs() < 1e-10);
1815 }
1816
1817 #[tokio::test]
1818 async fn test_prod_mixed_null_values() {
1819 let batch = make_nullable_test_batch(vec!["a", "a", "a"], vec![Some(0.6), None, Some(0.8)]);
1821 let input = make_memory_exec(batch);
1822 let result = execute_fold(
1823 input,
1824 vec![0],
1825 vec![FoldBinding {
1826 output_name: "prob".to_string(),
1827 kind: FoldAggKind::Prod,
1828 input_col_index: 1,
1829 }],
1830 )
1831 .await;
1832 let vals = result
1833 .column(1)
1834 .as_any()
1835 .downcast_ref::<Float64Array>()
1836 .unwrap();
1837 assert!((vals.value(0) - 0.48).abs() < 1e-10);
1838 }
1839
1840 #[tokio::test]
1841 async fn test_nor_many_small_values() {
1842 let names: Vec<&str> = vec!["a"; 20];
1844 let values: Vec<f64> = vec![0.1; 20];
1845 let batch = make_test_batch(names, values);
1846 let input = make_memory_exec(batch);
1847 let result = execute_fold(
1848 input,
1849 vec![0],
1850 vec![FoldBinding {
1851 output_name: "prob".to_string(),
1852 kind: FoldAggKind::Nor,
1853 input_col_index: 1,
1854 }],
1855 )
1856 .await;
1857 let vals = result
1858 .column(1)
1859 .as_any()
1860 .downcast_ref::<Float64Array>()
1861 .unwrap();
1862 let expected = 1.0 - 0.9_f64.powi(20);
1863 assert!(
1864 (vals.value(0) - expected).abs() < 1e-10,
1865 "expected {}, got {}",
1866 expected,
1867 vals.value(0)
1868 );
1869 }
1870
1871 #[test]
1874 fn test_is_monotonic() {
1875 assert!(FoldAggKind::Sum.is_monotonic());
1876 assert!(FoldAggKind::Max.is_monotonic());
1877 assert!(FoldAggKind::Min.is_monotonic());
1878 assert!(FoldAggKind::Count.is_monotonic());
1879 assert!(FoldAggKind::Nor.is_monotonic());
1880 assert!(FoldAggKind::Prod.is_monotonic());
1881 assert!(!FoldAggKind::Avg.is_monotonic());
1882 assert!(!FoldAggKind::Collect.is_monotonic());
1883 }
1884
1885 #[test]
1886 fn test_monotonicity_direction() {
1887 use super::MonotonicDirection;
1888 assert_eq!(
1889 FoldAggKind::Sum.monotonicity_direction(),
1890 Some(MonotonicDirection::NonDecreasing)
1891 );
1892 assert_eq!(
1893 FoldAggKind::Max.monotonicity_direction(),
1894 Some(MonotonicDirection::NonDecreasing)
1895 );
1896 assert_eq!(
1897 FoldAggKind::Count.monotonicity_direction(),
1898 Some(MonotonicDirection::NonDecreasing)
1899 );
1900 assert_eq!(
1901 FoldAggKind::Nor.monotonicity_direction(),
1902 Some(MonotonicDirection::NonDecreasing)
1903 );
1904 assert_eq!(
1905 FoldAggKind::Min.monotonicity_direction(),
1906 Some(MonotonicDirection::NonIncreasing)
1907 );
1908 assert_eq!(
1909 FoldAggKind::Prod.monotonicity_direction(),
1910 Some(MonotonicDirection::NonIncreasing)
1911 );
1912 assert_eq!(FoldAggKind::Avg.monotonicity_direction(), None);
1913 assert_eq!(FoldAggKind::Collect.monotonicity_direction(), None);
1914 }
1915
1916 #[test]
1917 fn test_identity_values() {
1918 assert_eq!(FoldAggKind::Sum.identity(), Some(0.0));
1919 assert_eq!(FoldAggKind::Count.identity(), Some(0.0));
1920 assert_eq!(FoldAggKind::Nor.identity(), Some(0.0));
1921 assert_eq!(FoldAggKind::Max.identity(), Some(f64::NEG_INFINITY));
1922 assert_eq!(FoldAggKind::Min.identity(), Some(f64::INFINITY));
1923 assert_eq!(FoldAggKind::Prod.identity(), Some(1.0));
1924 assert_eq!(FoldAggKind::Avg.identity(), None);
1925 assert_eq!(FoldAggKind::Collect.identity(), None);
1926 }
1927
1928 async fn execute_fold_strict(
1931 input: Arc<dyn ExecutionPlan>,
1932 key_indices: Vec<usize>,
1933 fold_bindings: Vec<FoldBinding>,
1934 strict: bool,
1935 ) -> DFResult<RecordBatch> {
1936 let exec = FoldExec::new(input, key_indices, fold_bindings, strict, 1e-15);
1937 let ctx = SessionContext::new();
1938 let task_ctx = ctx.task_ctx();
1939 let stream = exec.execute(0, task_ctx).unwrap();
1940 let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream).await?;
1941 if batches.is_empty() {
1942 Ok(RecordBatch::new_empty(exec.schema()))
1943 } else {
1944 arrow::compute::concat_batches(&exec.schema(), &batches).map_err(arrow_err)
1945 }
1946 }
1947
1948 #[tokio::test]
1949 async fn test_nor_strict_rejects_above_one() {
1950 let batch = make_test_batch(vec!["a"], vec![1.5]);
1951 let input = make_memory_exec(batch);
1952 let result = execute_fold_strict(
1953 input,
1954 vec![0],
1955 vec![FoldBinding {
1956 output_name: "p".into(),
1957 kind: FoldAggKind::Nor,
1958 input_col_index: 1,
1959 }],
1960 true,
1961 )
1962 .await;
1963 assert!(result.is_err());
1964 let err = result.unwrap_err().to_string();
1965 assert!(
1966 err.contains("strict_probability_domain"),
1967 "Expected strict error, got: {}",
1968 err
1969 );
1970 }
1971
1972 #[tokio::test]
1973 async fn test_nor_strict_rejects_negative() {
1974 let batch = make_test_batch(vec!["a"], vec![-0.1]);
1975 let input = make_memory_exec(batch);
1976 let result = execute_fold_strict(
1977 input,
1978 vec![0],
1979 vec![FoldBinding {
1980 output_name: "p".into(),
1981 kind: FoldAggKind::Nor,
1982 input_col_index: 1,
1983 }],
1984 true,
1985 )
1986 .await;
1987 assert!(result.is_err());
1988 let err = result.unwrap_err().to_string();
1989 assert!(
1990 err.contains("strict_probability_domain"),
1991 "Expected strict error, got: {}",
1992 err
1993 );
1994 }
1995
1996 #[tokio::test]
1997 async fn test_prod_strict_rejects_above_one() {
1998 let batch = make_test_batch(vec!["a"], vec![2.0]);
1999 let input = make_memory_exec(batch);
2000 let result = execute_fold_strict(
2001 input,
2002 vec![0],
2003 vec![FoldBinding {
2004 output_name: "p".into(),
2005 kind: FoldAggKind::Prod,
2006 input_col_index: 1,
2007 }],
2008 true,
2009 )
2010 .await;
2011 assert!(result.is_err());
2012 let err = result.unwrap_err().to_string();
2013 assert!(
2014 err.contains("strict_probability_domain"),
2015 "Expected strict error, got: {}",
2016 err
2017 );
2018 }
2019
2020 #[tokio::test]
2021 async fn test_prod_strict_rejects_negative() {
2022 let batch = make_test_batch(vec!["a"], vec![-0.5]);
2023 let input = make_memory_exec(batch);
2024 let result = execute_fold_strict(
2025 input,
2026 vec![0],
2027 vec![FoldBinding {
2028 output_name: "p".into(),
2029 kind: FoldAggKind::Prod,
2030 input_col_index: 1,
2031 }],
2032 true,
2033 )
2034 .await;
2035 assert!(result.is_err());
2036 let err = result.unwrap_err().to_string();
2037 assert!(
2038 err.contains("strict_probability_domain"),
2039 "Expected strict error, got: {}",
2040 err
2041 );
2042 }
2043
2044 #[tokio::test]
2045 async fn test_nor_strict_accepts_valid() {
2046 let batch = make_test_batch(vec!["a", "a"], vec![0.3, 0.5]);
2047 let input = make_memory_exec(batch);
2048 let result = execute_fold_strict(
2049 input,
2050 vec![0],
2051 vec![FoldBinding {
2052 output_name: "p".into(),
2053 kind: FoldAggKind::Nor,
2054 input_col_index: 1,
2055 }],
2056 true,
2057 )
2058 .await;
2059 assert!(result.is_ok(), "Expected Ok, got: {:?}", result.err());
2060 let batch = result.unwrap();
2061 let vals = batch
2062 .column(1)
2063 .as_any()
2064 .downcast_ref::<Float64Array>()
2065 .unwrap();
2066 let expected = 0.65; assert!(
2068 (vals.value(0) - expected).abs() < 1e-10,
2069 "expected {}, got {}",
2070 expected,
2071 vals.value(0)
2072 );
2073 }
2074
2075 #[tokio::test]
2076 async fn test_count_all_groups_by_key() {
2077 let batch = make_test_batch(vec!["a", "a", "b"], vec![10.0, 20.0, 30.0]);
2079 let input = make_memory_exec(batch);
2080 let result = execute_fold(
2081 input,
2082 vec![0],
2083 vec![FoldBinding {
2084 output_name: "cnt".to_string(),
2085 kind: FoldAggKind::CountAll,
2086 input_col_index: 0, }],
2088 )
2089 .await;
2090
2091 assert_eq!(result.num_rows(), 2, "Should have 2 groups");
2092 let counts = result
2093 .column(1)
2094 .as_any()
2095 .downcast_ref::<Int64Array>()
2096 .unwrap();
2097 assert_eq!(counts.value(0), 2, "Group 'a' should have count 2");
2098 assert_eq!(counts.value(1), 1, "Group 'b' should have count 1");
2099 }
2100}