sql_cli/sql/aggregates/
functions.rs

1//! Concrete implementations of aggregate functions
2
3use anyhow::Result;
4
5use super::{AggregateFunction, AggregateState, AvgState, MinMaxState, SumState, VarianceState};
6use crate::data::datatable::DataValue;
7
8/// COUNT(*) - counts all rows including nulls
9pub struct CountStarFunction;
10
11impl AggregateFunction for CountStarFunction {
12    fn name(&self) -> &'static str {
13        "COUNT_STAR"
14    }
15
16    fn init(&self) -> AggregateState {
17        AggregateState::Count(0)
18    }
19
20    fn accumulate(&self, state: &mut AggregateState, _value: &DataValue) -> Result<()> {
21        if let AggregateState::Count(ref mut count) = state {
22            *count += 1;
23        }
24        Ok(())
25    }
26
27    fn finalize(&self, state: AggregateState) -> DataValue {
28        if let AggregateState::Count(count) = state {
29            DataValue::Integer(count)
30        } else {
31            DataValue::Null
32        }
33    }
34}
35
36/// COUNT(column) - counts non-null values
37pub struct CountFunction;
38
39impl AggregateFunction for CountFunction {
40    fn name(&self) -> &'static str {
41        "COUNT"
42    }
43
44    fn init(&self) -> AggregateState {
45        AggregateState::Count(0)
46    }
47
48    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
49        if let AggregateState::Count(ref mut count) = state {
50            if !matches!(value, DataValue::Null) {
51                *count += 1;
52            }
53        }
54        Ok(())
55    }
56
57    fn finalize(&self, state: AggregateState) -> DataValue {
58        if let AggregateState::Count(count) = state {
59            DataValue::Integer(count)
60        } else {
61            DataValue::Null
62        }
63    }
64}
65
66/// SUM(column) - sums numeric values
67pub struct SumFunction;
68
69impl AggregateFunction for SumFunction {
70    fn name(&self) -> &'static str {
71        "SUM"
72    }
73
74    fn init(&self) -> AggregateState {
75        AggregateState::Sum(SumState::new())
76    }
77
78    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
79        if let AggregateState::Sum(ref mut sum_state) = state {
80            sum_state.add(value)?;
81        }
82        Ok(())
83    }
84
85    fn finalize(&self, state: AggregateState) -> DataValue {
86        if let AggregateState::Sum(sum_state) = state {
87            sum_state.finalize()
88        } else {
89            DataValue::Null
90        }
91    }
92
93    fn requires_numeric(&self) -> bool {
94        true
95    }
96}
97
98/// AVG(column) - averages numeric values
99pub struct AvgFunction;
100
101impl AggregateFunction for AvgFunction {
102    fn name(&self) -> &'static str {
103        "AVG"
104    }
105
106    fn init(&self) -> AggregateState {
107        AggregateState::Avg(AvgState::new())
108    }
109
110    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
111        if let AggregateState::Avg(ref mut avg_state) = state {
112            avg_state.add(value)?;
113        }
114        Ok(())
115    }
116
117    fn finalize(&self, state: AggregateState) -> DataValue {
118        if let AggregateState::Avg(avg_state) = state {
119            avg_state.finalize()
120        } else {
121            DataValue::Null
122        }
123    }
124
125    fn requires_numeric(&self) -> bool {
126        true
127    }
128}
129
130/// MIN(column) - finds minimum value
131pub struct MinFunction;
132
133impl AggregateFunction for MinFunction {
134    fn name(&self) -> &'static str {
135        "MIN"
136    }
137
138    fn init(&self) -> AggregateState {
139        AggregateState::MinMax(MinMaxState::new(true))
140    }
141
142    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
143        if let AggregateState::MinMax(ref mut minmax_state) = state {
144            minmax_state.add(value)?;
145        }
146        Ok(())
147    }
148
149    fn finalize(&self, state: AggregateState) -> DataValue {
150        if let AggregateState::MinMax(minmax_state) = state {
151            minmax_state.finalize()
152        } else {
153            DataValue::Null
154        }
155    }
156}
157
158/// MAX(column) - finds maximum value
159pub struct MaxFunction;
160
161impl AggregateFunction for MaxFunction {
162    fn name(&self) -> &'static str {
163        "MAX"
164    }
165
166    fn init(&self) -> AggregateState {
167        AggregateState::MinMax(MinMaxState::new(false))
168    }
169
170    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
171        if let AggregateState::MinMax(ref mut minmax_state) = state {
172            minmax_state.add(value)?;
173        }
174        Ok(())
175    }
176
177    fn finalize(&self, state: AggregateState) -> DataValue {
178        if let AggregateState::MinMax(minmax_state) = state {
179            minmax_state.finalize()
180        } else {
181            DataValue::Null
182        }
183    }
184}
185
186/// VARIANCE(column) - computes population variance
187pub struct VarianceFunction;
188
189impl AggregateFunction for VarianceFunction {
190    fn name(&self) -> &'static str {
191        "VARIANCE"
192    }
193
194    fn init(&self) -> AggregateState {
195        AggregateState::Variance(VarianceState::new())
196    }
197
198    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
199        if let AggregateState::Variance(ref mut var_state) = state {
200            var_state.add(value)?;
201        }
202        Ok(())
203    }
204
205    fn finalize(&self, state: AggregateState) -> DataValue {
206        if let AggregateState::Variance(var_state) = state {
207            var_state.finalize_variance()
208        } else {
209            DataValue::Null
210        }
211    }
212
213    fn requires_numeric(&self) -> bool {
214        true
215    }
216}
217
218/// STDDEV(column) - computes population standard deviation
219pub struct StdDevFunction;
220
221impl AggregateFunction for StdDevFunction {
222    fn name(&self) -> &'static str {
223        "STDDEV"
224    }
225
226    fn init(&self) -> AggregateState {
227        AggregateState::Variance(VarianceState::new())
228    }
229
230    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
231        if let AggregateState::Variance(ref mut var_state) = state {
232            var_state.add(value)?;
233        }
234        Ok(())
235    }
236
237    fn finalize(&self, state: AggregateState) -> DataValue {
238        if let AggregateState::Variance(var_state) = state {
239            var_state.finalize_stddev()
240        } else {
241            DataValue::Null
242        }
243    }
244
245    fn requires_numeric(&self) -> bool {
246        true
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `func` over `values` in order (init -> accumulate* -> finalize)
    /// and returns the finalized result, panicking if any accumulate fails.
    fn run<F: AggregateFunction>(func: &F, values: &[DataValue]) -> DataValue {
        let mut state = func.init();
        for value in values {
            func.accumulate(&mut state, value).unwrap();
        }
        func.finalize(state)
    }

    /// Asserts that `actual` is a `Float` within 0.001 of `expected`.
    fn assert_float_near(actual: DataValue, expected: f64) {
        match actual {
            DataValue::Float(f) => assert!((f - expected).abs() < 0.001),
            _ => panic!("Expected Float result"),
        }
    }

    fn int(n: i64) -> DataValue {
        DataValue::Integer(n)
    }

    fn text(s: &str) -> DataValue {
        DataValue::String(s.to_string())
    }

    #[test]
    fn test_count_star() {
        // COUNT(*) counts everything including nulls
        let out = run(&CountStarFunction, &[int(5), DataValue::Null, text("test")]);
        assert_eq!(out, DataValue::Integer(3));
    }

    #[test]
    fn test_count_column() {
        // COUNT(column) skips nulls
        let rows = [int(5), DataValue::Null, text("test"), DataValue::Null];
        assert_eq!(run(&CountFunction, &rows), DataValue::Integer(2));
    }

    #[test]
    fn test_sum_integers() {
        // The trailing NULL must be ignored.
        let rows = [int(10), int(20), int(30), DataValue::Null];
        assert_eq!(run(&SumFunction, &rows), DataValue::Integer(60));
    }

    #[test]
    fn test_sum_mixed() {
        // A single Float promotes the whole sum to Float.
        let rows = [int(10), DataValue::Float(20.5), int(30)];
        assert_float_near(run(&SumFunction, &rows), 60.5);
    }

    #[test]
    fn test_avg() {
        let rows = [int(10), int(20), int(30), DataValue::Null];
        assert_float_near(run(&AvgFunction, &rows), 20.0);
    }

    #[test]
    fn test_min() {
        let rows = [int(30), int(10), int(20), DataValue::Null];
        assert_eq!(run(&MinFunction, &rows), DataValue::Integer(10));
    }

    #[test]
    fn test_max() {
        let rows = [int(10), int(30), int(20), DataValue::Null];
        assert_eq!(run(&MaxFunction, &rows), DataValue::Integer(30));
    }

    #[test]
    fn test_max_strings() {
        let rows = [text("apple"), text("zebra"), text("banana")];
        assert_eq!(run(&MaxFunction, &rows), DataValue::String("zebra".to_string()));
    }

    #[test]
    fn test_variance() {
        // [2, 4, 6, 8, 10]: mean = 6, population variance = 8
        let rows: Vec<DataValue> = [2, 4, 6, 8, 10].iter().map(|&n| int(n)).collect();
        assert_float_near(run(&VarianceFunction, &rows), 8.0);
    }

    #[test]
    fn test_stddev() {
        // [2, 4, 6, 8, 10]: variance = 8, stddev = sqrt(8) ~= 2.828
        let rows: Vec<DataValue> = [2, 4, 6, 8, 10].iter().map(|&n| int(n)).collect();
        assert_float_near(run(&StdDevFunction, &rows), 2.8284271247461903);
    }

    #[test]
    fn test_variance_with_nulls() {
        // Values [5, 10, 15] (NULL ignored): mean = 10,
        // variance = (25 + 0 + 25) / 3 ~= 16.67
        let rows = [int(5), DataValue::Null, int(10), int(15)];
        assert_float_near(run(&VarianceFunction, &rows), 16.666666666666668);
    }
}