sql_cli/sql/aggregates/
functions.rs

1//! Concrete implementations of aggregate functions
2
3use anyhow::Result;
4
5use super::{
6    AggregateFunction, AggregateState, AvgState, MinMaxState, ModeState, PercentileState,
7    StringAggState, SumState, VarianceState,
8};
9use crate::data::datatable::DataValue;
10
11/// COUNT(*) - counts all rows including nulls
12pub struct CountStarFunction;
13
14impl AggregateFunction for CountStarFunction {
15    fn name(&self) -> &'static str {
16        "COUNT_STAR"
17    }
18
19    fn init(&self) -> AggregateState {
20        AggregateState::Count(0)
21    }
22
23    fn accumulate(&self, state: &mut AggregateState, _value: &DataValue) -> Result<()> {
24        if let AggregateState::Count(ref mut count) = state {
25            *count += 1;
26        }
27        Ok(())
28    }
29
30    fn finalize(&self, state: AggregateState) -> DataValue {
31        if let AggregateState::Count(count) = state {
32            DataValue::Integer(count)
33        } else {
34            DataValue::Null
35        }
36    }
37}
38
39/// COUNT(column) - counts non-null values
40pub struct CountFunction;
41
42impl AggregateFunction for CountFunction {
43    fn name(&self) -> &'static str {
44        "COUNT"
45    }
46
47    fn init(&self) -> AggregateState {
48        AggregateState::Count(0)
49    }
50
51    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
52        if let AggregateState::Count(ref mut count) = state {
53            if !matches!(value, DataValue::Null) {
54                *count += 1;
55            }
56        }
57        Ok(())
58    }
59
60    fn finalize(&self, state: AggregateState) -> DataValue {
61        if let AggregateState::Count(count) = state {
62            DataValue::Integer(count)
63        } else {
64            DataValue::Null
65        }
66    }
67}
68
69/// SUM(column) - sums numeric values
70pub struct SumFunction;
71
72impl AggregateFunction for SumFunction {
73    fn name(&self) -> &'static str {
74        "SUM"
75    }
76
77    fn init(&self) -> AggregateState {
78        AggregateState::Sum(SumState::new())
79    }
80
81    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
82        if let AggregateState::Sum(ref mut sum_state) = state {
83            sum_state.add(value)?;
84        }
85        Ok(())
86    }
87
88    fn finalize(&self, state: AggregateState) -> DataValue {
89        if let AggregateState::Sum(sum_state) = state {
90            sum_state.finalize()
91        } else {
92            DataValue::Null
93        }
94    }
95
96    fn requires_numeric(&self) -> bool {
97        true
98    }
99}
100
101/// AVG(column) - averages numeric values
102pub struct AvgFunction;
103
104impl AggregateFunction for AvgFunction {
105    fn name(&self) -> &'static str {
106        "AVG"
107    }
108
109    fn init(&self) -> AggregateState {
110        AggregateState::Avg(AvgState::new())
111    }
112
113    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
114        if let AggregateState::Avg(ref mut avg_state) = state {
115            avg_state.add(value)?;
116        }
117        Ok(())
118    }
119
120    fn finalize(&self, state: AggregateState) -> DataValue {
121        if let AggregateState::Avg(avg_state) = state {
122            avg_state.finalize()
123        } else {
124            DataValue::Null
125        }
126    }
127
128    fn requires_numeric(&self) -> bool {
129        true
130    }
131}
132
133/// MIN(column) - finds minimum value
134pub struct MinFunction;
135
136impl AggregateFunction for MinFunction {
137    fn name(&self) -> &'static str {
138        "MIN"
139    }
140
141    fn init(&self) -> AggregateState {
142        AggregateState::MinMax(MinMaxState::new(true))
143    }
144
145    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
146        if let AggregateState::MinMax(ref mut minmax_state) = state {
147            minmax_state.add(value)?;
148        }
149        Ok(())
150    }
151
152    fn finalize(&self, state: AggregateState) -> DataValue {
153        if let AggregateState::MinMax(minmax_state) = state {
154            minmax_state.finalize()
155        } else {
156            DataValue::Null
157        }
158    }
159}
160
161/// MAX(column) - finds maximum value
162pub struct MaxFunction;
163
164impl AggregateFunction for MaxFunction {
165    fn name(&self) -> &'static str {
166        "MAX"
167    }
168
169    fn init(&self) -> AggregateState {
170        AggregateState::MinMax(MinMaxState::new(false))
171    }
172
173    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
174        if let AggregateState::MinMax(ref mut minmax_state) = state {
175            minmax_state.add(value)?;
176        }
177        Ok(())
178    }
179
180    fn finalize(&self, state: AggregateState) -> DataValue {
181        if let AggregateState::MinMax(minmax_state) = state {
182            minmax_state.finalize()
183        } else {
184            DataValue::Null
185        }
186    }
187}
188
189/// VARIANCE(column) - computes population variance
190pub struct VarianceFunction;
191
192impl AggregateFunction for VarianceFunction {
193    fn name(&self) -> &'static str {
194        "VARIANCE"
195    }
196
197    fn init(&self) -> AggregateState {
198        AggregateState::Variance(VarianceState::new())
199    }
200
201    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
202        if let AggregateState::Variance(ref mut var_state) = state {
203            var_state.add(value)?;
204        }
205        Ok(())
206    }
207
208    fn finalize(&self, state: AggregateState) -> DataValue {
209        if let AggregateState::Variance(var_state) = state {
210            var_state.finalize_variance()
211        } else {
212            DataValue::Null
213        }
214    }
215
216    fn requires_numeric(&self) -> bool {
217        true
218    }
219}
220
221/// STDDEV(column) - computes population standard deviation
222pub struct StdDevFunction;
223
224impl AggregateFunction for StdDevFunction {
225    fn name(&self) -> &'static str {
226        "STDDEV"
227    }
228
229    fn init(&self) -> AggregateState {
230        AggregateState::Variance(VarianceState::new())
231    }
232
233    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
234        if let AggregateState::Variance(ref mut var_state) = state {
235            var_state.add(value)?;
236        }
237        Ok(())
238    }
239
240    fn finalize(&self, state: AggregateState) -> DataValue {
241        if let AggregateState::Variance(var_state) = state {
242            var_state.finalize_stddev()
243        } else {
244            DataValue::Null
245        }
246    }
247
248    fn requires_numeric(&self) -> bool {
249        true
250    }
251}
252
253/// STRING_AGG(column, separator) - concatenates strings with separator
254pub struct StringAggFunction;
255
256impl AggregateFunction for StringAggFunction {
257    fn name(&self) -> &'static str {
258        "STRING_AGG"
259    }
260
261    fn init(&self) -> AggregateState {
262        AggregateState::StringAgg(StringAggState::new(","))
263    }
264
265    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
266        if let AggregateState::StringAgg(ref mut agg_state) = state {
267            agg_state.add(value)?;
268        }
269        Ok(())
270    }
271
272    fn finalize(&self, state: AggregateState) -> DataValue {
273        if let AggregateState::StringAgg(agg_state) = state {
274            agg_state.finalize()
275        } else {
276            DataValue::Null
277        }
278    }
279}
280
281/// MEDIAN(column) - finds the median value
282pub struct MedianFunction;
283
284impl AggregateFunction for MedianFunction {
285    fn name(&self) -> &'static str {
286        "MEDIAN"
287    }
288
289    fn init(&self) -> AggregateState {
290        AggregateState::CollectList(Vec::new())
291    }
292
293    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
294        if let AggregateState::CollectList(ref mut values) = state {
295            // Skip null values
296            if !matches!(value, DataValue::Null) {
297                values.push(value.clone());
298            }
299        }
300        Ok(())
301    }
302
303    fn finalize(&self, state: AggregateState) -> DataValue {
304        if let AggregateState::CollectList(mut values) = state {
305            if values.is_empty() {
306                return DataValue::Null;
307            }
308
309            // Sort values for median calculation
310            values.sort_by(|a, b| {
311                use std::cmp::Ordering;
312                match (a, b) {
313                    (DataValue::Integer(a), DataValue::Integer(b)) => a.cmp(b),
314                    (DataValue::Float(a), DataValue::Float(b)) => {
315                        a.partial_cmp(b).unwrap_or(Ordering::Equal)
316                    }
317                    (DataValue::Integer(a), DataValue::Float(b)) => {
318                        (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal)
319                    }
320                    (DataValue::Float(a), DataValue::Integer(b)) => {
321                        a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal)
322                    }
323                    (DataValue::String(a), DataValue::String(b)) => a.cmp(b),
324                    (DataValue::InternedString(a), DataValue::InternedString(b)) => a.cmp(b),
325                    (DataValue::String(a), DataValue::InternedString(b)) => a.cmp(&**b),
326                    (DataValue::InternedString(a), DataValue::String(b)) => (**a).cmp(b),
327                    _ => Ordering::Equal,
328                }
329            });
330
331            let len = values.len();
332            if len % 2 == 1 {
333                // Odd number of elements - return middle element
334                values[len / 2].clone()
335            } else {
336                // Even number of elements - return average of two middle elements
337                let mid1 = &values[len / 2 - 1];
338                let mid2 = &values[len / 2];
339
340                // For numeric values, calculate average
341                match (mid1, mid2) {
342                    (DataValue::Integer(a), DataValue::Integer(b)) => {
343                        let avg = (*a + *b) as f64 / 2.0;
344                        if avg.fract() == 0.0 {
345                            DataValue::Integer(avg as i64)
346                        } else {
347                            DataValue::Float(avg)
348                        }
349                    }
350                    (DataValue::Float(a), DataValue::Float(b)) => DataValue::Float((a + b) / 2.0),
351                    (DataValue::Integer(a), DataValue::Float(b)) => {
352                        DataValue::Float((*a as f64 + b) / 2.0)
353                    }
354                    (DataValue::Float(a), DataValue::Integer(b)) => {
355                        DataValue::Float((a + *b as f64) / 2.0)
356                    }
357                    // For non-numeric, return the first middle element
358                    _ => mid1.clone(),
359                }
360            }
361        } else {
362            DataValue::Null
363        }
364    }
365
366    fn requires_numeric(&self) -> bool {
367        false // MEDIAN works on any sortable type
368    }
369}
370
371/// PERCENTILE(column, percentile) - finds the nth percentile value
372pub struct PercentileFunction;
373
374impl AggregateFunction for PercentileFunction {
375    fn name(&self) -> &'static str {
376        "PERCENTILE"
377    }
378
379    fn init(&self) -> AggregateState {
380        AggregateState::Percentile(PercentileState::new(50.0))
381    }
382
383    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
384        if let AggregateState::Percentile(ref mut percentile_state) = state {
385            percentile_state.add(value)?;
386        }
387        Ok(())
388    }
389
390    fn finalize(&self, state: AggregateState) -> DataValue {
391        if let AggregateState::Percentile(percentile_state) = state {
392            percentile_state.finalize()
393        } else {
394            DataValue::Null
395        }
396    }
397
398    fn requires_numeric(&self) -> bool {
399        true // PERCENTILE typically works on numeric data
400    }
401}
402
/// MODE(column) - finds the most frequently occurring value
///
/// Unit marker type; its `AggregateFunction` impl (further down in this file) keeps its
/// running state in a `ModeState`.
pub struct ModeFunction;
405
/// STDDEV_POP(column) - computes population standard deviation (same as STDDEV)
///
/// Unit marker type; its `AggregateFunction` impl finalizes with
/// `VarianceState::finalize_stddev`, the same call `StdDevFunction` uses.
pub struct StdDevPopFunction;
408
/// STDDEV_SAMP(column) - computes sample standard deviation
///
/// Unit marker type; its `AggregateFunction` impl finalizes with
/// `VarianceState::finalize_stddev_sample`.
pub struct StdDevSampFunction;
411
/// VAR_POP(column) - computes population variance (same as VARIANCE)
///
/// Unit marker type; its `AggregateFunction` impl finalizes with
/// `VarianceState::finalize_variance`, the same call `VarianceFunction` uses.
pub struct VarPopFunction;
414
/// VAR_SAMP(column) - computes sample variance
///
/// Unit marker type; its `AggregateFunction` impl finalizes with
/// `VarianceState::finalize_variance_sample`.
pub struct VarSampFunction;
417
418impl AggregateFunction for StdDevPopFunction {
419    fn name(&self) -> &'static str {
420        "STDDEV_POP"
421    }
422
423    fn init(&self) -> AggregateState {
424        AggregateState::Variance(VarianceState::new())
425    }
426
427    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
428        if let AggregateState::Variance(ref mut var_state) = state {
429            var_state.add(value)?;
430        }
431        Ok(())
432    }
433
434    fn finalize(&self, state: AggregateState) -> DataValue {
435        if let AggregateState::Variance(var_state) = state {
436            var_state.finalize_stddev()
437        } else {
438            DataValue::Null
439        }
440    }
441
442    fn requires_numeric(&self) -> bool {
443        true
444    }
445}
446
447impl AggregateFunction for StdDevSampFunction {
448    fn name(&self) -> &'static str {
449        "STDDEV_SAMP"
450    }
451
452    fn init(&self) -> AggregateState {
453        AggregateState::Variance(VarianceState::new())
454    }
455
456    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
457        if let AggregateState::Variance(ref mut var_state) = state {
458            var_state.add(value)?;
459        }
460        Ok(())
461    }
462
463    fn finalize(&self, state: AggregateState) -> DataValue {
464        if let AggregateState::Variance(var_state) = state {
465            var_state.finalize_stddev_sample()
466        } else {
467            DataValue::Null
468        }
469    }
470
471    fn requires_numeric(&self) -> bool {
472        true
473    }
474}
475
476impl AggregateFunction for VarPopFunction {
477    fn name(&self) -> &'static str {
478        "VAR_POP"
479    }
480
481    fn init(&self) -> AggregateState {
482        AggregateState::Variance(VarianceState::new())
483    }
484
485    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
486        if let AggregateState::Variance(ref mut var_state) = state {
487            var_state.add(value)?;
488        }
489        Ok(())
490    }
491
492    fn finalize(&self, state: AggregateState) -> DataValue {
493        if let AggregateState::Variance(var_state) = state {
494            var_state.finalize_variance()
495        } else {
496            DataValue::Null
497        }
498    }
499
500    fn requires_numeric(&self) -> bool {
501        true
502    }
503}
504
505impl AggregateFunction for VarSampFunction {
506    fn name(&self) -> &'static str {
507        "VAR_SAMP"
508    }
509
510    fn init(&self) -> AggregateState {
511        AggregateState::Variance(VarianceState::new())
512    }
513
514    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
515        if let AggregateState::Variance(ref mut var_state) = state {
516            var_state.add(value)?;
517        }
518        Ok(())
519    }
520
521    fn finalize(&self, state: AggregateState) -> DataValue {
522        if let AggregateState::Variance(var_state) = state {
523            var_state.finalize_variance_sample()
524        } else {
525            DataValue::Null
526        }
527    }
528
529    fn requires_numeric(&self) -> bool {
530        true
531    }
532}
533
534impl AggregateFunction for ModeFunction {
535    fn name(&self) -> &'static str {
536        "MODE"
537    }
538
539    fn init(&self) -> AggregateState {
540        AggregateState::Mode(ModeState::new())
541    }
542
543    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
544        if let AggregateState::Mode(ref mut mode_state) = state {
545            mode_state.add(value)?;
546        }
547        Ok(())
548    }
549
550    fn finalize(&self, state: AggregateState) -> DataValue {
551        if let AggregateState::Mode(mode_state) = state {
552            mode_state.finalize()
553        } else {
554            DataValue::Null
555        }
556    }
557
558    fn requires_numeric(&self) -> bool {
559        false // MODE works on any data type
560    }
561}
562
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives `func` over `inputs` starting from a fresh state and returns the final value.
    fn run<F: AggregateFunction>(func: F, inputs: Vec<DataValue>) -> DataValue {
        let mut state = func.init();
        for input in &inputs {
            func.accumulate(&mut state, input).unwrap();
        }
        func.finalize(state)
    }

    /// Shorthand for building an owned string value.
    fn text(s: &str) -> DataValue {
        DataValue::String(s.to_string())
    }

    /// Asserts that `actual` is a float within 0.001 of `expected`.
    fn assert_float_close(actual: DataValue, expected: f64) {
        match actual {
            DataValue::Float(f) => assert!((f - expected).abs() < 0.001),
            _ => panic!("Expected Float result"),
        }
    }

    #[test]
    fn test_count_star() {
        // COUNT(*) counts everything including nulls
        let result = run(
            CountStarFunction,
            vec![DataValue::Integer(5), DataValue::Null, text("test")],
        );
        assert_eq!(result, DataValue::Integer(3));
    }

    #[test]
    fn test_count_column() {
        // COUNT(column) skips nulls
        let result = run(
            CountFunction,
            vec![
                DataValue::Integer(5),
                DataValue::Null,
                text("test"),
                DataValue::Null,
            ],
        );
        assert_eq!(result, DataValue::Integer(2));
    }

    #[test]
    fn test_sum_integers() {
        // The trailing null is ignored
        let result = run(
            SumFunction,
            vec![
                DataValue::Integer(10),
                DataValue::Integer(20),
                DataValue::Integer(30),
                DataValue::Null,
            ],
        );
        assert_eq!(result, DataValue::Integer(60));
    }

    #[test]
    fn test_sum_mixed() {
        // A float among the inputs turns the result into a float
        let result = run(
            SumFunction,
            vec![
                DataValue::Integer(10),
                DataValue::Float(20.5),
                DataValue::Integer(30),
            ],
        );
        assert_float_close(result, 60.5);
    }

    #[test]
    fn test_avg() {
        // The trailing null is ignored
        let result = run(
            AvgFunction,
            vec![
                DataValue::Integer(10),
                DataValue::Integer(20),
                DataValue::Integer(30),
                DataValue::Null,
            ],
        );
        assert_float_close(result, 20.0);
    }

    #[test]
    fn test_min() {
        // The trailing null is ignored
        let result = run(
            MinFunction,
            vec![
                DataValue::Integer(30),
                DataValue::Integer(10),
                DataValue::Integer(20),
                DataValue::Null,
            ],
        );
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_max() {
        // The trailing null is ignored
        let result = run(
            MaxFunction,
            vec![
                DataValue::Integer(10),
                DataValue::Integer(30),
                DataValue::Integer(20),
                DataValue::Null,
            ],
        );
        assert_eq!(result, DataValue::Integer(30));
    }

    #[test]
    fn test_max_strings() {
        let result = run(
            MaxFunction,
            vec![text("apple"), text("zebra"), text("banana")],
        );
        assert_eq!(result, DataValue::String("zebra".to_string()));
    }

    #[test]
    fn test_variance() {
        // Test data: [2, 4, 6, 8, 10]
        // Mean = 6, Variance = 8
        let inputs = [2, 4, 6, 8, 10]
            .iter()
            .map(|&n| DataValue::Integer(n))
            .collect();
        assert_float_close(run(VarianceFunction, inputs), 8.0);
    }

    #[test]
    fn test_stddev() {
        // Test data: [2, 4, 6, 8, 10]
        // Mean = 6, Variance = 8, StdDev = sqrt(8) ≈ 2.828
        let inputs = [2, 4, 6, 8, 10]
            .iter()
            .map(|&n| DataValue::Integer(n))
            .collect();
        assert_float_close(run(StdDevFunction, inputs), 2.8284271247461903);
    }

    #[test]
    fn test_variance_with_nulls() {
        // The null should be ignored, leaving values = [5, 10, 15] with mean = 10.
        // Variance = ((5-10)² + (10-10)² + (15-10)²) / 3 = (25 + 0 + 25) / 3 ≈ 16.67
        let result = run(
            VarianceFunction,
            vec![
                DataValue::Integer(5),
                DataValue::Null,
                DataValue::Integer(10),
                DataValue::Integer(15),
            ],
        );
        assert_float_close(result, 16.666666666666668);
    }

    #[test]
    fn test_string_agg() {
        // The null should be ignored
        let result = run(
            StringAggFunction,
            vec![
                text("apple"),
                text("banana"),
                DataValue::Null,
                text("cherry"),
            ],
        );
        assert_eq!(result, DataValue::String("apple,banana,cherry".to_string()));
    }
}
785}