1use anyhow::Result;
4
5use super::{
6 AggregateFunction, AggregateState, AvgState, MinMaxState, ModeState, PercentileState,
7 StringAggState, SumState, VarianceState,
8};
9use crate::data::datatable::DataValue;
10
11pub struct CountStarFunction;
13
14impl AggregateFunction for CountStarFunction {
15 fn name(&self) -> &'static str {
16 "COUNT_STAR"
17 }
18
19 fn init(&self) -> AggregateState {
20 AggregateState::Count(0)
21 }
22
23 fn accumulate(&self, state: &mut AggregateState, _value: &DataValue) -> Result<()> {
24 if let AggregateState::Count(ref mut count) = state {
25 *count += 1;
26 }
27 Ok(())
28 }
29
30 fn finalize(&self, state: AggregateState) -> DataValue {
31 if let AggregateState::Count(count) = state {
32 DataValue::Integer(count)
33 } else {
34 DataValue::Null
35 }
36 }
37}
38
39pub struct CountFunction;
41
42impl AggregateFunction for CountFunction {
43 fn name(&self) -> &'static str {
44 "COUNT"
45 }
46
47 fn init(&self) -> AggregateState {
48 AggregateState::Count(0)
49 }
50
51 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
52 if let AggregateState::Count(ref mut count) = state {
53 if !matches!(value, DataValue::Null) {
54 *count += 1;
55 }
56 }
57 Ok(())
58 }
59
60 fn finalize(&self, state: AggregateState) -> DataValue {
61 if let AggregateState::Count(count) = state {
62 DataValue::Integer(count)
63 } else {
64 DataValue::Null
65 }
66 }
67}
68
69pub struct SumFunction;
71
72impl AggregateFunction for SumFunction {
73 fn name(&self) -> &'static str {
74 "SUM"
75 }
76
77 fn init(&self) -> AggregateState {
78 AggregateState::Sum(SumState::new())
79 }
80
81 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
82 if let AggregateState::Sum(ref mut sum_state) = state {
83 sum_state.add(value)?;
84 }
85 Ok(())
86 }
87
88 fn finalize(&self, state: AggregateState) -> DataValue {
89 if let AggregateState::Sum(sum_state) = state {
90 sum_state.finalize()
91 } else {
92 DataValue::Null
93 }
94 }
95
96 fn requires_numeric(&self) -> bool {
97 true
98 }
99}
100
101pub struct AvgFunction;
103
104impl AggregateFunction for AvgFunction {
105 fn name(&self) -> &'static str {
106 "AVG"
107 }
108
109 fn init(&self) -> AggregateState {
110 AggregateState::Avg(AvgState::new())
111 }
112
113 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
114 if let AggregateState::Avg(ref mut avg_state) = state {
115 avg_state.add(value)?;
116 }
117 Ok(())
118 }
119
120 fn finalize(&self, state: AggregateState) -> DataValue {
121 if let AggregateState::Avg(avg_state) = state {
122 avg_state.finalize()
123 } else {
124 DataValue::Null
125 }
126 }
127
128 fn requires_numeric(&self) -> bool {
129 true
130 }
131}
132
133pub struct MinFunction;
135
136impl AggregateFunction for MinFunction {
137 fn name(&self) -> &'static str {
138 "MIN"
139 }
140
141 fn init(&self) -> AggregateState {
142 AggregateState::MinMax(MinMaxState::new(true))
143 }
144
145 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
146 if let AggregateState::MinMax(ref mut minmax_state) = state {
147 minmax_state.add(value)?;
148 }
149 Ok(())
150 }
151
152 fn finalize(&self, state: AggregateState) -> DataValue {
153 if let AggregateState::MinMax(minmax_state) = state {
154 minmax_state.finalize()
155 } else {
156 DataValue::Null
157 }
158 }
159}
160
161pub struct MaxFunction;
163
164impl AggregateFunction for MaxFunction {
165 fn name(&self) -> &'static str {
166 "MAX"
167 }
168
169 fn init(&self) -> AggregateState {
170 AggregateState::MinMax(MinMaxState::new(false))
171 }
172
173 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
174 if let AggregateState::MinMax(ref mut minmax_state) = state {
175 minmax_state.add(value)?;
176 }
177 Ok(())
178 }
179
180 fn finalize(&self, state: AggregateState) -> DataValue {
181 if let AggregateState::MinMax(minmax_state) = state {
182 minmax_state.finalize()
183 } else {
184 DataValue::Null
185 }
186 }
187}
188
189pub struct VarianceFunction;
191
192impl AggregateFunction for VarianceFunction {
193 fn name(&self) -> &'static str {
194 "VARIANCE"
195 }
196
197 fn init(&self) -> AggregateState {
198 AggregateState::Variance(VarianceState::new())
199 }
200
201 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
202 if let AggregateState::Variance(ref mut var_state) = state {
203 var_state.add(value)?;
204 }
205 Ok(())
206 }
207
208 fn finalize(&self, state: AggregateState) -> DataValue {
209 if let AggregateState::Variance(var_state) = state {
210 var_state.finalize_variance()
211 } else {
212 DataValue::Null
213 }
214 }
215
216 fn requires_numeric(&self) -> bool {
217 true
218 }
219}
220
221pub struct StdDevFunction;
223
224impl AggregateFunction for StdDevFunction {
225 fn name(&self) -> &'static str {
226 "STDDEV"
227 }
228
229 fn init(&self) -> AggregateState {
230 AggregateState::Variance(VarianceState::new())
231 }
232
233 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
234 if let AggregateState::Variance(ref mut var_state) = state {
235 var_state.add(value)?;
236 }
237 Ok(())
238 }
239
240 fn finalize(&self, state: AggregateState) -> DataValue {
241 if let AggregateState::Variance(var_state) = state {
242 var_state.finalize_stddev()
243 } else {
244 DataValue::Null
245 }
246 }
247
248 fn requires_numeric(&self) -> bool {
249 true
250 }
251}
252
253pub struct StringAggFunction;
255
256impl AggregateFunction for StringAggFunction {
257 fn name(&self) -> &'static str {
258 "STRING_AGG"
259 }
260
261 fn init(&self) -> AggregateState {
262 AggregateState::StringAgg(StringAggState::new(","))
263 }
264
265 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
266 if let AggregateState::StringAgg(ref mut agg_state) = state {
267 agg_state.add(value)?;
268 }
269 Ok(())
270 }
271
272 fn finalize(&self, state: AggregateState) -> DataValue {
273 if let AggregateState::StringAgg(agg_state) = state {
274 agg_state.finalize()
275 } else {
276 DataValue::Null
277 }
278 }
279}
280
281pub struct MedianFunction;
283
284impl AggregateFunction for MedianFunction {
285 fn name(&self) -> &'static str {
286 "MEDIAN"
287 }
288
289 fn init(&self) -> AggregateState {
290 AggregateState::CollectList(Vec::new())
291 }
292
293 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
294 if let AggregateState::CollectList(ref mut values) = state {
295 if !matches!(value, DataValue::Null) {
297 values.push(value.clone());
298 }
299 }
300 Ok(())
301 }
302
303 fn finalize(&self, state: AggregateState) -> DataValue {
304 if let AggregateState::CollectList(mut values) = state {
305 if values.is_empty() {
306 return DataValue::Null;
307 }
308
309 values.sort_by(|a, b| {
311 use std::cmp::Ordering;
312 match (a, b) {
313 (DataValue::Integer(a), DataValue::Integer(b)) => a.cmp(b),
314 (DataValue::Float(a), DataValue::Float(b)) => {
315 a.partial_cmp(b).unwrap_or(Ordering::Equal)
316 }
317 (DataValue::Integer(a), DataValue::Float(b)) => {
318 (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal)
319 }
320 (DataValue::Float(a), DataValue::Integer(b)) => {
321 a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal)
322 }
323 (DataValue::String(a), DataValue::String(b)) => a.cmp(b),
324 (DataValue::InternedString(a), DataValue::InternedString(b)) => a.cmp(b),
325 (DataValue::String(a), DataValue::InternedString(b)) => a.cmp(&**b),
326 (DataValue::InternedString(a), DataValue::String(b)) => (**a).cmp(b),
327 _ => Ordering::Equal,
328 }
329 });
330
331 let len = values.len();
332 if len % 2 == 1 {
333 values[len / 2].clone()
335 } else {
336 let mid1 = &values[len / 2 - 1];
338 let mid2 = &values[len / 2];
339
340 match (mid1, mid2) {
342 (DataValue::Integer(a), DataValue::Integer(b)) => {
343 let avg = (*a + *b) as f64 / 2.0;
344 if avg.fract() == 0.0 {
345 DataValue::Integer(avg as i64)
346 } else {
347 DataValue::Float(avg)
348 }
349 }
350 (DataValue::Float(a), DataValue::Float(b)) => DataValue::Float((a + b) / 2.0),
351 (DataValue::Integer(a), DataValue::Float(b)) => {
352 DataValue::Float((*a as f64 + b) / 2.0)
353 }
354 (DataValue::Float(a), DataValue::Integer(b)) => {
355 DataValue::Float((a + *b as f64) / 2.0)
356 }
357 _ => mid1.clone(),
359 }
360 }
361 } else {
362 DataValue::Null
363 }
364 }
365
366 fn requires_numeric(&self) -> bool {
367 false }
369}
370
371pub struct PercentileFunction;
373
374impl AggregateFunction for PercentileFunction {
375 fn name(&self) -> &'static str {
376 "PERCENTILE"
377 }
378
379 fn init(&self) -> AggregateState {
380 AggregateState::Percentile(PercentileState::new(50.0))
381 }
382
383 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
384 if let AggregateState::Percentile(ref mut percentile_state) = state {
385 percentile_state.add(value)?;
386 }
387 Ok(())
388 }
389
390 fn finalize(&self, state: AggregateState) -> DataValue {
391 if let AggregateState::Percentile(percentile_state) = state {
392 percentile_state.finalize()
393 } else {
394 DataValue::Null
395 }
396 }
397
398 fn requires_numeric(&self) -> bool {
399 true }
401}
402
403pub struct ModeFunction;
405
406impl AggregateFunction for ModeFunction {
407 fn name(&self) -> &'static str {
408 "MODE"
409 }
410
411 fn init(&self) -> AggregateState {
412 AggregateState::Mode(ModeState::new())
413 }
414
415 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
416 if let AggregateState::Mode(ref mut mode_state) = state {
417 mode_state.add(value)?;
418 }
419 Ok(())
420 }
421
422 fn finalize(&self, state: AggregateState) -> DataValue {
423 if let AggregateState::Mode(mode_state) = state {
424 mode_state.finalize()
425 } else {
426 DataValue::Null
427 }
428 }
429
430 fn requires_numeric(&self) -> bool {
431 false }
433}
434
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `inputs` through a fresh state of `func`, in order, and returns
    /// the finalized value. Any accumulate error fails the test.
    fn aggregate<F: AggregateFunction>(func: &F, inputs: &[DataValue]) -> DataValue {
        let mut state = func.init();
        for input in inputs {
            func.accumulate(&mut state, input).unwrap();
        }
        func.finalize(state)
    }

    /// Shorthand for an owned string value.
    fn text(s: &str) -> DataValue {
        DataValue::String(s.to_string())
    }

    /// Asserts that `actual` is a `Float` within 0.001 of `expected`.
    fn assert_float_near(actual: DataValue, expected: f64) {
        match actual {
            DataValue::Float(f) => assert!((f - expected).abs() < 0.001),
            _ => panic!("Expected Float result"),
        }
    }

    #[test]
    fn test_count_star() {
        // Nulls are counted too.
        let inputs = [DataValue::Integer(5), DataValue::Null, text("test")];
        let result = aggregate(&CountStarFunction, &inputs);
        assert_eq!(result, DataValue::Integer(3));
    }

    #[test]
    fn test_count_column() {
        // Only the two non-null values are counted.
        let inputs = [
            DataValue::Integer(5),
            DataValue::Null,
            text("test"),
            DataValue::Null,
        ];
        let result = aggregate(&CountFunction, &inputs);
        assert_eq!(result, DataValue::Integer(2));
    }

    #[test]
    fn test_sum_integers() {
        let inputs = [
            DataValue::Integer(10),
            DataValue::Integer(20),
            DataValue::Integer(30),
            DataValue::Null,
        ];
        let result = aggregate(&SumFunction, &inputs);
        assert_eq!(result, DataValue::Integer(60));
    }

    #[test]
    fn test_sum_mixed() {
        // One float among the inputs makes the total a float.
        let inputs = [
            DataValue::Integer(10),
            DataValue::Float(20.5),
            DataValue::Integer(30),
        ];
        assert_float_near(aggregate(&SumFunction, &inputs), 60.5);
    }

    #[test]
    fn test_avg() {
        let inputs = [
            DataValue::Integer(10),
            DataValue::Integer(20),
            DataValue::Integer(30),
            DataValue::Null,
        ];
        assert_float_near(aggregate(&AvgFunction, &inputs), 20.0);
    }

    #[test]
    fn test_min() {
        let inputs = [
            DataValue::Integer(30),
            DataValue::Integer(10),
            DataValue::Integer(20),
            DataValue::Null,
        ];
        let result = aggregate(&MinFunction, &inputs);
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_max() {
        let inputs = [
            DataValue::Integer(10),
            DataValue::Integer(30),
            DataValue::Integer(20),
            DataValue::Null,
        ];
        let result = aggregate(&MaxFunction, &inputs);
        assert_eq!(result, DataValue::Integer(30));
    }

    #[test]
    fn test_max_strings() {
        let inputs = [text("apple"), text("zebra"), text("banana")];
        let result = aggregate(&MaxFunction, &inputs);
        assert_eq!(result, DataValue::String("zebra".to_string()));
    }

    #[test]
    fn test_variance() {
        let inputs = [
            DataValue::Integer(2),
            DataValue::Integer(4),
            DataValue::Integer(6),
            DataValue::Integer(8),
            DataValue::Integer(10),
        ];
        assert_float_near(aggregate(&VarianceFunction, &inputs), 8.0);
    }

    #[test]
    fn test_stddev() {
        let inputs = [
            DataValue::Integer(2),
            DataValue::Integer(4),
            DataValue::Integer(6),
            DataValue::Integer(8),
            DataValue::Integer(10),
        ];
        assert_float_near(aggregate(&StdDevFunction, &inputs), 2.8284271247461903);
    }

    #[test]
    fn test_variance_with_nulls() {
        let inputs = [
            DataValue::Integer(5),
            DataValue::Null,
            DataValue::Integer(10),
            DataValue::Integer(15),
        ];
        assert_float_near(aggregate(&VarianceFunction, &inputs), 16.666666666666668);
    }

    #[test]
    fn test_string_agg() {
        let inputs = [
            text("apple"),
            text("banana"),
            DataValue::Null,
            text("cherry"),
        ];
        let result = aggregate(&StringAggFunction, &inputs);
        assert_eq!(result, DataValue::String("apple,banana,cherry".to_string()));
    }
}