sql_cli/sql/aggregates/
functions.rs

1//! Concrete implementations of aggregate functions
2
3use anyhow::Result;
4
5use super::{
6    AggregateFunction, AggregateState, AvgState, MinMaxState, ModeState, PercentileState,
7    StringAggState, SumState, VarianceState,
8};
9use crate::data::datatable::DataValue;
10
11/// COUNT(*) - counts all rows including nulls
12pub struct CountStarFunction;
13
14impl AggregateFunction for CountStarFunction {
15    fn name(&self) -> &'static str {
16        "COUNT_STAR"
17    }
18
19    fn init(&self) -> AggregateState {
20        AggregateState::Count(0)
21    }
22
23    fn accumulate(&self, state: &mut AggregateState, _value: &DataValue) -> Result<()> {
24        if let AggregateState::Count(ref mut count) = state {
25            *count += 1;
26        }
27        Ok(())
28    }
29
30    fn finalize(&self, state: AggregateState) -> DataValue {
31        if let AggregateState::Count(count) = state {
32            DataValue::Integer(count)
33        } else {
34            DataValue::Null
35        }
36    }
37}
38
39/// COUNT(column) - counts non-null values
40pub struct CountFunction;
41
42impl AggregateFunction for CountFunction {
43    fn name(&self) -> &'static str {
44        "COUNT"
45    }
46
47    fn init(&self) -> AggregateState {
48        AggregateState::Count(0)
49    }
50
51    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
52        if let AggregateState::Count(ref mut count) = state {
53            if !matches!(value, DataValue::Null) {
54                *count += 1;
55            }
56        }
57        Ok(())
58    }
59
60    fn finalize(&self, state: AggregateState) -> DataValue {
61        if let AggregateState::Count(count) = state {
62            DataValue::Integer(count)
63        } else {
64            DataValue::Null
65        }
66    }
67}
68
69/// SUM(column) - sums numeric values
70pub struct SumFunction;
71
72impl AggregateFunction for SumFunction {
73    fn name(&self) -> &'static str {
74        "SUM"
75    }
76
77    fn init(&self) -> AggregateState {
78        AggregateState::Sum(SumState::new())
79    }
80
81    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
82        if let AggregateState::Sum(ref mut sum_state) = state {
83            sum_state.add(value)?;
84        }
85        Ok(())
86    }
87
88    fn finalize(&self, state: AggregateState) -> DataValue {
89        if let AggregateState::Sum(sum_state) = state {
90            sum_state.finalize()
91        } else {
92            DataValue::Null
93        }
94    }
95
96    fn requires_numeric(&self) -> bool {
97        true
98    }
99}
100
101/// AVG(column) - averages numeric values
102pub struct AvgFunction;
103
104impl AggregateFunction for AvgFunction {
105    fn name(&self) -> &'static str {
106        "AVG"
107    }
108
109    fn init(&self) -> AggregateState {
110        AggregateState::Avg(AvgState::new())
111    }
112
113    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
114        if let AggregateState::Avg(ref mut avg_state) = state {
115            avg_state.add(value)?;
116        }
117        Ok(())
118    }
119
120    fn finalize(&self, state: AggregateState) -> DataValue {
121        if let AggregateState::Avg(avg_state) = state {
122            avg_state.finalize()
123        } else {
124            DataValue::Null
125        }
126    }
127
128    fn requires_numeric(&self) -> bool {
129        true
130    }
131}
132
133/// MIN(column) - finds minimum value
134pub struct MinFunction;
135
136impl AggregateFunction for MinFunction {
137    fn name(&self) -> &'static str {
138        "MIN"
139    }
140
141    fn init(&self) -> AggregateState {
142        AggregateState::MinMax(MinMaxState::new(true))
143    }
144
145    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
146        if let AggregateState::MinMax(ref mut minmax_state) = state {
147            minmax_state.add(value)?;
148        }
149        Ok(())
150    }
151
152    fn finalize(&self, state: AggregateState) -> DataValue {
153        if let AggregateState::MinMax(minmax_state) = state {
154            minmax_state.finalize()
155        } else {
156            DataValue::Null
157        }
158    }
159}
160
161/// MAX(column) - finds maximum value
162pub struct MaxFunction;
163
164impl AggregateFunction for MaxFunction {
165    fn name(&self) -> &'static str {
166        "MAX"
167    }
168
169    fn init(&self) -> AggregateState {
170        AggregateState::MinMax(MinMaxState::new(false))
171    }
172
173    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
174        if let AggregateState::MinMax(ref mut minmax_state) = state {
175            minmax_state.add(value)?;
176        }
177        Ok(())
178    }
179
180    fn finalize(&self, state: AggregateState) -> DataValue {
181        if let AggregateState::MinMax(minmax_state) = state {
182            minmax_state.finalize()
183        } else {
184            DataValue::Null
185        }
186    }
187}
188
189/// VARIANCE(column) - computes population variance
190pub struct VarianceFunction;
191
192impl AggregateFunction for VarianceFunction {
193    fn name(&self) -> &'static str {
194        "VARIANCE"
195    }
196
197    fn init(&self) -> AggregateState {
198        AggregateState::Variance(VarianceState::new())
199    }
200
201    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
202        if let AggregateState::Variance(ref mut var_state) = state {
203            var_state.add(value)?;
204        }
205        Ok(())
206    }
207
208    fn finalize(&self, state: AggregateState) -> DataValue {
209        if let AggregateState::Variance(var_state) = state {
210            var_state.finalize_variance()
211        } else {
212            DataValue::Null
213        }
214    }
215
216    fn requires_numeric(&self) -> bool {
217        true
218    }
219}
220
221/// STDDEV(column) - computes population standard deviation
222pub struct StdDevFunction;
223
224impl AggregateFunction for StdDevFunction {
225    fn name(&self) -> &'static str {
226        "STDDEV"
227    }
228
229    fn init(&self) -> AggregateState {
230        AggregateState::Variance(VarianceState::new())
231    }
232
233    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
234        if let AggregateState::Variance(ref mut var_state) = state {
235            var_state.add(value)?;
236        }
237        Ok(())
238    }
239
240    fn finalize(&self, state: AggregateState) -> DataValue {
241        if let AggregateState::Variance(var_state) = state {
242            var_state.finalize_stddev()
243        } else {
244            DataValue::Null
245        }
246    }
247
248    fn requires_numeric(&self) -> bool {
249        true
250    }
251}
252
253/// STRING_AGG(column, separator) - concatenates strings with separator
254pub struct StringAggFunction;
255
256impl AggregateFunction for StringAggFunction {
257    fn name(&self) -> &'static str {
258        "STRING_AGG"
259    }
260
261    fn init(&self) -> AggregateState {
262        AggregateState::StringAgg(StringAggState::new(","))
263    }
264
265    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
266        if let AggregateState::StringAgg(ref mut agg_state) = state {
267            agg_state.add(value)?;
268        }
269        Ok(())
270    }
271
272    fn finalize(&self, state: AggregateState) -> DataValue {
273        if let AggregateState::StringAgg(agg_state) = state {
274            agg_state.finalize()
275        } else {
276            DataValue::Null
277        }
278    }
279}
280
281/// MEDIAN(column) - finds the median value
282pub struct MedianFunction;
283
284impl AggregateFunction for MedianFunction {
285    fn name(&self) -> &'static str {
286        "MEDIAN"
287    }
288
289    fn init(&self) -> AggregateState {
290        AggregateState::CollectList(Vec::new())
291    }
292
293    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
294        if let AggregateState::CollectList(ref mut values) = state {
295            // Skip null values
296            if !matches!(value, DataValue::Null) {
297                values.push(value.clone());
298            }
299        }
300        Ok(())
301    }
302
303    fn finalize(&self, state: AggregateState) -> DataValue {
304        if let AggregateState::CollectList(mut values) = state {
305            if values.is_empty() {
306                return DataValue::Null;
307            }
308
309            // Sort values for median calculation
310            values.sort_by(|a, b| {
311                use std::cmp::Ordering;
312                match (a, b) {
313                    (DataValue::Integer(a), DataValue::Integer(b)) => a.cmp(b),
314                    (DataValue::Float(a), DataValue::Float(b)) => {
315                        a.partial_cmp(b).unwrap_or(Ordering::Equal)
316                    }
317                    (DataValue::Integer(a), DataValue::Float(b)) => {
318                        (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal)
319                    }
320                    (DataValue::Float(a), DataValue::Integer(b)) => {
321                        a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal)
322                    }
323                    (DataValue::String(a), DataValue::String(b)) => a.cmp(b),
324                    (DataValue::InternedString(a), DataValue::InternedString(b)) => a.cmp(b),
325                    (DataValue::String(a), DataValue::InternedString(b)) => a.cmp(&**b),
326                    (DataValue::InternedString(a), DataValue::String(b)) => (**a).cmp(b),
327                    _ => Ordering::Equal,
328                }
329            });
330
331            let len = values.len();
332            if len % 2 == 1 {
333                // Odd number of elements - return middle element
334                values[len / 2].clone()
335            } else {
336                // Even number of elements - return average of two middle elements
337                let mid1 = &values[len / 2 - 1];
338                let mid2 = &values[len / 2];
339
340                // For numeric values, calculate average
341                match (mid1, mid2) {
342                    (DataValue::Integer(a), DataValue::Integer(b)) => {
343                        let avg = (*a + *b) as f64 / 2.0;
344                        if avg.fract() == 0.0 {
345                            DataValue::Integer(avg as i64)
346                        } else {
347                            DataValue::Float(avg)
348                        }
349                    }
350                    (DataValue::Float(a), DataValue::Float(b)) => DataValue::Float((a + b) / 2.0),
351                    (DataValue::Integer(a), DataValue::Float(b)) => {
352                        DataValue::Float((*a as f64 + b) / 2.0)
353                    }
354                    (DataValue::Float(a), DataValue::Integer(b)) => {
355                        DataValue::Float((a + *b as f64) / 2.0)
356                    }
357                    // For non-numeric, return the first middle element
358                    _ => mid1.clone(),
359                }
360            }
361        } else {
362            DataValue::Null
363        }
364    }
365
366    fn requires_numeric(&self) -> bool {
367        false // MEDIAN works on any sortable type
368    }
369}
370
371/// PERCENTILE(column, percentile) - finds the nth percentile value
372pub struct PercentileFunction;
373
374impl AggregateFunction for PercentileFunction {
375    fn name(&self) -> &'static str {
376        "PERCENTILE"
377    }
378
379    fn init(&self) -> AggregateState {
380        AggregateState::Percentile(PercentileState::new(50.0))
381    }
382
383    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
384        if let AggregateState::Percentile(ref mut percentile_state) = state {
385            percentile_state.add(value)?;
386        }
387        Ok(())
388    }
389
390    fn finalize(&self, state: AggregateState) -> DataValue {
391        if let AggregateState::Percentile(percentile_state) = state {
392            percentile_state.finalize()
393        } else {
394            DataValue::Null
395        }
396    }
397
398    fn requires_numeric(&self) -> bool {
399        true // PERCENTILE typically works on numeric data
400    }
401}
402
403/// MODE(column) - finds the most frequently occurring value
404pub struct ModeFunction;
405
406impl AggregateFunction for ModeFunction {
407    fn name(&self) -> &'static str {
408        "MODE"
409    }
410
411    fn init(&self) -> AggregateState {
412        AggregateState::Mode(ModeState::new())
413    }
414
415    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
416        if let AggregateState::Mode(ref mut mode_state) = state {
417            mode_state.add(value)?;
418        }
419        Ok(())
420    }
421
422    fn finalize(&self, state: AggregateState) -> DataValue {
423        if let AggregateState::Mode(mode_state) = state {
424            mode_state.finalize()
425        } else {
426            DataValue::Null
427        }
428    }
429
430    fn requires_numeric(&self) -> bool {
431        false // MODE works on any data type
432    }
433}
434
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `func` over `values` starting from a fresh state and returns the finalized result.
    fn run<F: AggregateFunction>(func: &F, values: &[DataValue]) -> DataValue {
        let mut state = func.init();
        for value in values {
            func.accumulate(&mut state, value).unwrap();
        }
        func.finalize(state)
    }

    /// Asserts that `result` is a `Float` within 0.001 of `expected`.
    fn assert_float_near(result: DataValue, expected: f64) {
        match result {
            DataValue::Float(f) => assert!(
                (f - expected).abs() < 0.001,
                "expected ~{}, got {}",
                expected,
                f
            ),
            _ => panic!("Expected Float result"),
        }
    }

    fn int(v: i64) -> DataValue {
        DataValue::Integer(v)
    }

    fn text(s: &str) -> DataValue {
        DataValue::String(s.to_string())
    }

    #[test]
    fn test_count_star() {
        // COUNT(*) counts everything including nulls
        let result = run(&CountStarFunction, &[int(5), DataValue::Null, text("test")]);
        assert_eq!(result, int(3));
    }

    #[test]
    fn test_count_column() {
        // COUNT(column) skips nulls
        let values = [int(5), DataValue::Null, text("test"), DataValue::Null];
        assert_eq!(run(&CountFunction, &values), int(2));
    }

    #[test]
    fn test_sum_integers() {
        let values = [int(10), int(20), int(30), DataValue::Null]; // NULL ignored
        assert_eq!(run(&SumFunction, &values), int(60));
    }

    #[test]
    fn test_sum_mixed() {
        // A float anywhere in the input promotes the sum to Float
        let values = [int(10), DataValue::Float(20.5), int(30)];
        assert_float_near(run(&SumFunction, &values), 60.5);
    }

    #[test]
    fn test_avg() {
        let values = [int(10), int(20), int(30), DataValue::Null]; // NULL ignored
        assert_float_near(run(&AvgFunction, &values), 20.0);
    }

    #[test]
    fn test_min() {
        let values = [int(30), int(10), int(20), DataValue::Null]; // NULL ignored
        assert_eq!(run(&MinFunction, &values), int(10));
    }

    #[test]
    fn test_max() {
        let values = [int(10), int(30), int(20), DataValue::Null]; // NULL ignored
        assert_eq!(run(&MaxFunction, &values), int(30));
    }

    #[test]
    fn test_max_strings() {
        let values = [text("apple"), text("zebra"), text("banana")];
        assert_eq!(run(&MaxFunction, &values), text("zebra"));
    }

    #[test]
    fn test_variance() {
        // [2, 4, 6, 8, 10]: mean = 6, population variance = 8
        let values = [int(2), int(4), int(6), int(8), int(10)];
        assert_float_near(run(&VarianceFunction, &values), 8.0);
    }

    #[test]
    fn test_stddev() {
        // [2, 4, 6, 8, 10]: population variance = 8, stddev = sqrt(8) ≈ 2.828
        let values = [int(2), int(4), int(6), int(8), int(10)];
        assert_float_near(run(&StdDevFunction, &values), 8.0_f64.sqrt());
    }

    #[test]
    fn test_variance_with_nulls() {
        // NULL ignored, values = [5, 10, 15], mean = 10
        // Variance = ((5-10)² + (10-10)² + (15-10)²) / 3 = 50 / 3 ≈ 16.67
        let values = [int(5), DataValue::Null, int(10), int(15)];
        assert_float_near(run(&VarianceFunction, &values), 50.0 / 3.0);
    }

    #[test]
    fn test_string_agg() {
        let values = [text("apple"), text("banana"), DataValue::Null, text("cherry")];
        assert_eq!(
            run(&StringAggFunction, &values),
            text("apple,banana,cherry")
        );
    }

    #[test]
    fn test_median_odd_count_returns_middle_value() {
        // NULL skipped, sorted input is [10, 20, 30]
        let values = [int(30), DataValue::Null, int(10), int(20)];
        assert_eq!(run(&MedianFunction, &values), int(20));
    }

    #[test]
    fn test_median_even_count_whole_average_stays_integer() {
        // Middle values 20 and 30 average to exactly 25
        let values = [int(40), int(10), int(30), int(20)];
        assert_eq!(run(&MedianFunction, &values), int(25));
    }

    #[test]
    fn test_median_even_count_fractional_average_is_float() {
        // Middle values 2 and 3 average to 2.5
        let values = [int(4), int(1), int(3), int(2)];
        assert_float_near(run(&MedianFunction, &values), 2.5);
    }

    #[test]
    fn test_median_mixed_numeric() {
        // Sorted input is [1, 2.5, 4]
        let values = [DataValue::Float(2.5), int(4), int(1)];
        assert_float_near(run(&MedianFunction, &values), 2.5);
    }

    #[test]
    fn test_median_even_strings_returns_lower_middle() {
        // Sorted input is [a, b, c, d]; strings cannot be averaged
        let values = [text("b"), text("d"), text("a"), text("c")];
        assert_eq!(run(&MedianFunction, &values), text("b"));
    }

    #[test]
    fn test_median_of_only_nulls_is_null() {
        let values = [DataValue::Null, DataValue::Null];
        assert_eq!(run(&MedianFunction, &values), DataValue::Null);
    }
}