1use anyhow::Result;
4
5use super::{
6 AggregateFunction, AggregateState, AvgState, MinMaxState, ModeState, PercentileState,
7 StringAggState, SumState, VarianceState,
8};
9use crate::data::datatable::DataValue;
10
11pub struct CountStarFunction;
13
14impl AggregateFunction for CountStarFunction {
15 fn name(&self) -> &'static str {
16 "COUNT_STAR"
17 }
18
19 fn init(&self) -> AggregateState {
20 AggregateState::Count(0)
21 }
22
23 fn accumulate(&self, state: &mut AggregateState, _value: &DataValue) -> Result<()> {
24 if let AggregateState::Count(ref mut count) = state {
25 *count += 1;
26 }
27 Ok(())
28 }
29
30 fn finalize(&self, state: AggregateState) -> DataValue {
31 if let AggregateState::Count(count) = state {
32 DataValue::Integer(count)
33 } else {
34 DataValue::Null
35 }
36 }
37}
38
39pub struct CountFunction;
41
42impl AggregateFunction for CountFunction {
43 fn name(&self) -> &'static str {
44 "COUNT"
45 }
46
47 fn init(&self) -> AggregateState {
48 AggregateState::Count(0)
49 }
50
51 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
52 if let AggregateState::Count(ref mut count) = state {
53 if !matches!(value, DataValue::Null) {
54 *count += 1;
55 }
56 }
57 Ok(())
58 }
59
60 fn finalize(&self, state: AggregateState) -> DataValue {
61 if let AggregateState::Count(count) = state {
62 DataValue::Integer(count)
63 } else {
64 DataValue::Null
65 }
66 }
67}
68
69pub struct SumFunction;
71
72impl AggregateFunction for SumFunction {
73 fn name(&self) -> &'static str {
74 "SUM"
75 }
76
77 fn init(&self) -> AggregateState {
78 AggregateState::Sum(SumState::new())
79 }
80
81 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
82 if let AggregateState::Sum(ref mut sum_state) = state {
83 sum_state.add(value)?;
84 }
85 Ok(())
86 }
87
88 fn finalize(&self, state: AggregateState) -> DataValue {
89 if let AggregateState::Sum(sum_state) = state {
90 sum_state.finalize()
91 } else {
92 DataValue::Null
93 }
94 }
95
96 fn requires_numeric(&self) -> bool {
97 true
98 }
99}
100
101pub struct AvgFunction;
103
104impl AggregateFunction for AvgFunction {
105 fn name(&self) -> &'static str {
106 "AVG"
107 }
108
109 fn init(&self) -> AggregateState {
110 AggregateState::Avg(AvgState::new())
111 }
112
113 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
114 if let AggregateState::Avg(ref mut avg_state) = state {
115 avg_state.add(value)?;
116 }
117 Ok(())
118 }
119
120 fn finalize(&self, state: AggregateState) -> DataValue {
121 if let AggregateState::Avg(avg_state) = state {
122 avg_state.finalize()
123 } else {
124 DataValue::Null
125 }
126 }
127
128 fn requires_numeric(&self) -> bool {
129 true
130 }
131}
132
133pub struct MinFunction;
135
136impl AggregateFunction for MinFunction {
137 fn name(&self) -> &'static str {
138 "MIN"
139 }
140
141 fn init(&self) -> AggregateState {
142 AggregateState::MinMax(MinMaxState::new(true))
143 }
144
145 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
146 if let AggregateState::MinMax(ref mut minmax_state) = state {
147 minmax_state.add(value)?;
148 }
149 Ok(())
150 }
151
152 fn finalize(&self, state: AggregateState) -> DataValue {
153 if let AggregateState::MinMax(minmax_state) = state {
154 minmax_state.finalize()
155 } else {
156 DataValue::Null
157 }
158 }
159}
160
161pub struct MaxFunction;
163
164impl AggregateFunction for MaxFunction {
165 fn name(&self) -> &'static str {
166 "MAX"
167 }
168
169 fn init(&self) -> AggregateState {
170 AggregateState::MinMax(MinMaxState::new(false))
171 }
172
173 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
174 if let AggregateState::MinMax(ref mut minmax_state) = state {
175 minmax_state.add(value)?;
176 }
177 Ok(())
178 }
179
180 fn finalize(&self, state: AggregateState) -> DataValue {
181 if let AggregateState::MinMax(minmax_state) = state {
182 minmax_state.finalize()
183 } else {
184 DataValue::Null
185 }
186 }
187}
188
189pub struct VarianceFunction;
191
192impl AggregateFunction for VarianceFunction {
193 fn name(&self) -> &'static str {
194 "VARIANCE"
195 }
196
197 fn init(&self) -> AggregateState {
198 AggregateState::Variance(VarianceState::new())
199 }
200
201 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
202 if let AggregateState::Variance(ref mut var_state) = state {
203 var_state.add(value)?;
204 }
205 Ok(())
206 }
207
208 fn finalize(&self, state: AggregateState) -> DataValue {
209 if let AggregateState::Variance(var_state) = state {
210 var_state.finalize_variance()
211 } else {
212 DataValue::Null
213 }
214 }
215
216 fn requires_numeric(&self) -> bool {
217 true
218 }
219}
220
221pub struct StdDevFunction;
223
224impl AggregateFunction for StdDevFunction {
225 fn name(&self) -> &'static str {
226 "STDDEV"
227 }
228
229 fn init(&self) -> AggregateState {
230 AggregateState::Variance(VarianceState::new())
231 }
232
233 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
234 if let AggregateState::Variance(ref mut var_state) = state {
235 var_state.add(value)?;
236 }
237 Ok(())
238 }
239
240 fn finalize(&self, state: AggregateState) -> DataValue {
241 if let AggregateState::Variance(var_state) = state {
242 var_state.finalize_stddev()
243 } else {
244 DataValue::Null
245 }
246 }
247
248 fn requires_numeric(&self) -> bool {
249 true
250 }
251}
252
253pub struct StringAggFunction;
255
256impl AggregateFunction for StringAggFunction {
257 fn name(&self) -> &'static str {
258 "STRING_AGG"
259 }
260
261 fn init(&self) -> AggregateState {
262 AggregateState::StringAgg(StringAggState::new(","))
263 }
264
265 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
266 if let AggregateState::StringAgg(ref mut agg_state) = state {
267 agg_state.add(value)?;
268 }
269 Ok(())
270 }
271
272 fn finalize(&self, state: AggregateState) -> DataValue {
273 if let AggregateState::StringAgg(agg_state) = state {
274 agg_state.finalize()
275 } else {
276 DataValue::Null
277 }
278 }
279}
280
281pub struct MedianFunction;
283
284impl AggregateFunction for MedianFunction {
285 fn name(&self) -> &'static str {
286 "MEDIAN"
287 }
288
289 fn init(&self) -> AggregateState {
290 AggregateState::CollectList(Vec::new())
291 }
292
293 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
294 if let AggregateState::CollectList(ref mut values) = state {
295 if !matches!(value, DataValue::Null) {
297 values.push(value.clone());
298 }
299 }
300 Ok(())
301 }
302
303 fn finalize(&self, state: AggregateState) -> DataValue {
304 if let AggregateState::CollectList(mut values) = state {
305 if values.is_empty() {
306 return DataValue::Null;
307 }
308
309 values.sort_by(|a, b| {
311 use std::cmp::Ordering;
312 match (a, b) {
313 (DataValue::Integer(a), DataValue::Integer(b)) => a.cmp(b),
314 (DataValue::Float(a), DataValue::Float(b)) => {
315 a.partial_cmp(b).unwrap_or(Ordering::Equal)
316 }
317 (DataValue::Integer(a), DataValue::Float(b)) => {
318 (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal)
319 }
320 (DataValue::Float(a), DataValue::Integer(b)) => {
321 a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal)
322 }
323 (DataValue::String(a), DataValue::String(b)) => a.cmp(b),
324 (DataValue::InternedString(a), DataValue::InternedString(b)) => a.cmp(b),
325 (DataValue::String(a), DataValue::InternedString(b)) => a.cmp(&**b),
326 (DataValue::InternedString(a), DataValue::String(b)) => (**a).cmp(b),
327 _ => Ordering::Equal,
328 }
329 });
330
331 let len = values.len();
332 if len % 2 == 1 {
333 values[len / 2].clone()
335 } else {
336 let mid1 = &values[len / 2 - 1];
338 let mid2 = &values[len / 2];
339
340 match (mid1, mid2) {
342 (DataValue::Integer(a), DataValue::Integer(b)) => {
343 let avg = (*a + *b) as f64 / 2.0;
344 if avg.fract() == 0.0 {
345 DataValue::Integer(avg as i64)
346 } else {
347 DataValue::Float(avg)
348 }
349 }
350 (DataValue::Float(a), DataValue::Float(b)) => DataValue::Float((a + b) / 2.0),
351 (DataValue::Integer(a), DataValue::Float(b)) => {
352 DataValue::Float((*a as f64 + b) / 2.0)
353 }
354 (DataValue::Float(a), DataValue::Integer(b)) => {
355 DataValue::Float((a + *b as f64) / 2.0)
356 }
357 _ => mid1.clone(),
359 }
360 }
361 } else {
362 DataValue::Null
363 }
364 }
365
366 fn requires_numeric(&self) -> bool {
367 false }
369}
370
371pub struct PercentileFunction;
373
374impl AggregateFunction for PercentileFunction {
375 fn name(&self) -> &'static str {
376 "PERCENTILE"
377 }
378
379 fn init(&self) -> AggregateState {
380 AggregateState::Percentile(PercentileState::new(50.0))
381 }
382
383 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
384 if let AggregateState::Percentile(ref mut percentile_state) = state {
385 percentile_state.add(value)?;
386 }
387 Ok(())
388 }
389
390 fn finalize(&self, state: AggregateState) -> DataValue {
391 if let AggregateState::Percentile(percentile_state) = state {
392 percentile_state.finalize()
393 } else {
394 DataValue::Null
395 }
396 }
397
398 fn requires_numeric(&self) -> bool {
399 true }
401}
402
/// Aggregate registered as `MODE`; the result is computed by `ModeState`.
pub struct ModeFunction;

/// Aggregate registered as `STDDEV_POP`; the result comes from
/// `VarianceState::finalize_stddev`.
pub struct StdDevPopFunction;

/// Aggregate registered as `STDDEV_SAMP`; the result comes from
/// `VarianceState::finalize_stddev_sample`.
pub struct StdDevSampFunction;

/// Aggregate registered as `VAR_POP`; the result comes from
/// `VarianceState::finalize_variance`.
pub struct VarPopFunction;

/// Aggregate registered as `VAR_SAMP`; the result comes from
/// `VarianceState::finalize_variance_sample`.
pub struct VarSampFunction;
417
418impl AggregateFunction for StdDevPopFunction {
419 fn name(&self) -> &'static str {
420 "STDDEV_POP"
421 }
422
423 fn init(&self) -> AggregateState {
424 AggregateState::Variance(VarianceState::new())
425 }
426
427 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
428 if let AggregateState::Variance(ref mut var_state) = state {
429 var_state.add(value)?;
430 }
431 Ok(())
432 }
433
434 fn finalize(&self, state: AggregateState) -> DataValue {
435 if let AggregateState::Variance(var_state) = state {
436 var_state.finalize_stddev()
437 } else {
438 DataValue::Null
439 }
440 }
441
442 fn requires_numeric(&self) -> bool {
443 true
444 }
445}
446
447impl AggregateFunction for StdDevSampFunction {
448 fn name(&self) -> &'static str {
449 "STDDEV_SAMP"
450 }
451
452 fn init(&self) -> AggregateState {
453 AggregateState::Variance(VarianceState::new())
454 }
455
456 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
457 if let AggregateState::Variance(ref mut var_state) = state {
458 var_state.add(value)?;
459 }
460 Ok(())
461 }
462
463 fn finalize(&self, state: AggregateState) -> DataValue {
464 if let AggregateState::Variance(var_state) = state {
465 var_state.finalize_stddev_sample()
466 } else {
467 DataValue::Null
468 }
469 }
470
471 fn requires_numeric(&self) -> bool {
472 true
473 }
474}
475
476impl AggregateFunction for VarPopFunction {
477 fn name(&self) -> &'static str {
478 "VAR_POP"
479 }
480
481 fn init(&self) -> AggregateState {
482 AggregateState::Variance(VarianceState::new())
483 }
484
485 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
486 if let AggregateState::Variance(ref mut var_state) = state {
487 var_state.add(value)?;
488 }
489 Ok(())
490 }
491
492 fn finalize(&self, state: AggregateState) -> DataValue {
493 if let AggregateState::Variance(var_state) = state {
494 var_state.finalize_variance()
495 } else {
496 DataValue::Null
497 }
498 }
499
500 fn requires_numeric(&self) -> bool {
501 true
502 }
503}
504
505impl AggregateFunction for VarSampFunction {
506 fn name(&self) -> &'static str {
507 "VAR_SAMP"
508 }
509
510 fn init(&self) -> AggregateState {
511 AggregateState::Variance(VarianceState::new())
512 }
513
514 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
515 if let AggregateState::Variance(ref mut var_state) = state {
516 var_state.add(value)?;
517 }
518 Ok(())
519 }
520
521 fn finalize(&self, state: AggregateState) -> DataValue {
522 if let AggregateState::Variance(var_state) = state {
523 var_state.finalize_variance_sample()
524 } else {
525 DataValue::Null
526 }
527 }
528
529 fn requires_numeric(&self) -> bool {
530 true
531 }
532}
533
534impl AggregateFunction for ModeFunction {
535 fn name(&self) -> &'static str {
536 "MODE"
537 }
538
539 fn init(&self) -> AggregateState {
540 AggregateState::Mode(ModeState::new())
541 }
542
543 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
544 if let AggregateState::Mode(ref mut mode_state) = state {
545 mode_state.add(value)?;
546 }
547 Ok(())
548 }
549
550 fn finalize(&self, state: AggregateState) -> DataValue {
551 if let AggregateState::Mode(mode_state) = state {
552 mode_state.finalize()
553 } else {
554 DataValue::Null
555 }
556 }
557
558 fn requires_numeric(&self) -> bool {
559 false }
561}
562
#[cfg(test)]
mod tests {
    //! Unit tests for the built-in aggregate functions.

    use super::*;

    #[test]
    fn test_count_star() {
        let func = CountStarFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::Integer(5)).unwrap();
        func.accumulate(&mut state, &DataValue::Null).unwrap();
        func.accumulate(&mut state, &DataValue::String("test".to_string()))
            .unwrap();

        let result = func.finalize(state);
        assert_eq!(result, DataValue::Integer(3));
    }

    #[test]
    fn test_count_star_empty() {
        let func = CountStarFunction;
        let state = func.init();
        assert_eq!(func.finalize(state), DataValue::Integer(0));
    }

    #[test]
    fn test_count_column() {
        let func = CountFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::Integer(5)).unwrap();
        func.accumulate(&mut state, &DataValue::Null).unwrap();
        func.accumulate(&mut state, &DataValue::String("test".to_string()))
            .unwrap();
        func.accumulate(&mut state, &DataValue::Null).unwrap();

        let result = func.finalize(state);
        assert_eq!(result, DataValue::Integer(2));
    }

    #[test]
    fn test_count_column_only_nulls() {
        let func = CountFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::Null).unwrap();
        func.accumulate(&mut state, &DataValue::Null).unwrap();

        assert_eq!(func.finalize(state), DataValue::Integer(0));
    }

    #[test]
    fn test_sum_integers() {
        let func = SumFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::Integer(10)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(20)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(30)).unwrap();
        func.accumulate(&mut state, &DataValue::Null).unwrap();

        let result = func.finalize(state);
        assert_eq!(result, DataValue::Integer(60));
    }

    #[test]
    fn test_sum_mixed() {
        let func = SumFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::Integer(10)).unwrap();
        func.accumulate(&mut state, &DataValue::Float(20.5)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(30)).unwrap();

        let result = func.finalize(state);
        match result {
            DataValue::Float(f) => assert!((f - 60.5).abs() < 0.001),
            _ => panic!("Expected Float result"),
        }
    }

    #[test]
    fn test_avg() {
        let func = AvgFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::Integer(10)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(20)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(30)).unwrap();
        func.accumulate(&mut state, &DataValue::Null).unwrap();

        let result = func.finalize(state);
        match result {
            DataValue::Float(f) => assert!((f - 20.0).abs() < 0.001),
            _ => panic!("Expected Float result"),
        }
    }

    #[test]
    fn test_min() {
        let func = MinFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::Integer(30)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(10)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(20)).unwrap();
        func.accumulate(&mut state, &DataValue::Null).unwrap();

        let result = func.finalize(state);
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_max() {
        let func = MaxFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::Integer(10)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(30)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(20)).unwrap();
        func.accumulate(&mut state, &DataValue::Null).unwrap();

        let result = func.finalize(state);
        assert_eq!(result, DataValue::Integer(30));
    }

    #[test]
    fn test_max_strings() {
        let func = MaxFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::String("apple".to_string()))
            .unwrap();
        func.accumulate(&mut state, &DataValue::String("zebra".to_string()))
            .unwrap();
        func.accumulate(&mut state, &DataValue::String("banana".to_string()))
            .unwrap();

        let result = func.finalize(state);
        assert_eq!(result, DataValue::String("zebra".to_string()));
    }

    #[test]
    fn test_variance() {
        let func = VarianceFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::Integer(2)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(4)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(6)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(8)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(10)).unwrap();

        let result = func.finalize(state);
        match result {
            DataValue::Float(f) => assert!((f - 8.0).abs() < 0.001),
            _ => panic!("Expected Float result"),
        }
    }

    #[test]
    fn test_stddev() {
        let func = StdDevFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::Integer(2)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(4)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(6)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(8)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(10)).unwrap();

        let result = func.finalize(state);
        match result {
            DataValue::Float(f) => assert!((f - 2.8284271247461903).abs() < 0.001),
            _ => panic!("Expected Float result"),
        }
    }

    #[test]
    fn test_variance_with_nulls() {
        let func = VarianceFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::Integer(5)).unwrap();
        func.accumulate(&mut state, &DataValue::Null).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(10)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(15)).unwrap();

        let result = func.finalize(state);
        match result {
            DataValue::Float(f) => {
                assert!((f - 16.666666666666668).abs() < 0.001);
            }
            _ => panic!("Expected Float result"),
        }
    }

    #[test]
    fn test_string_agg() {
        let func = StringAggFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::String("apple".to_string()))
            .unwrap();
        func.accumulate(&mut state, &DataValue::String("banana".to_string()))
            .unwrap();
        func.accumulate(&mut state, &DataValue::Null).unwrap();
        func.accumulate(&mut state, &DataValue::String("cherry".to_string()))
            .unwrap();

        let result = func.finalize(state);
        assert_eq!(result, DataValue::String("apple,banana,cherry".to_string()));
    }

    #[test]
    fn test_median_odd_count_skips_nulls() {
        let func = MedianFunction;
        let mut state = func.init();

        for v in [30, 10, 20] {
            func.accumulate(&mut state, &DataValue::Integer(v)).unwrap();
        }
        func.accumulate(&mut state, &DataValue::Null).unwrap();

        assert_eq!(func.finalize(state), DataValue::Integer(20));
    }

    #[test]
    fn test_median_even_count() {
        let func = MedianFunction;

        // Middle pair 2, 3 has an odd sum -> fractional midpoint.
        let mut state = func.init();
        for v in [4, 1, 3, 2] {
            func.accumulate(&mut state, &DataValue::Integer(v)).unwrap();
        }
        assert_eq!(func.finalize(state), DataValue::Float(2.5));

        // Middle pair 3, 5 has an even sum -> integer midpoint.
        let mut state = func.init();
        for v in [7, 1, 5, 3] {
            func.accumulate(&mut state, &DataValue::Integer(v)).unwrap();
        }
        assert_eq!(func.finalize(state), DataValue::Integer(4));
    }

    #[test]
    fn test_median_strings() {
        let func = MedianFunction;
        let mut state = func.init();

        for s in ["cherry", "apple", "banana"] {
            func.accumulate(&mut state, &DataValue::String(s.to_string()))
                .unwrap();
        }

        assert_eq!(func.finalize(state), DataValue::String("banana".to_string()));
    }

    #[test]
    fn test_median_all_null() {
        let func = MedianFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::Null).unwrap();

        assert_eq!(func.finalize(state), DataValue::Null);
    }
}