sql_cli/sql/aggregates/
functions.rs

1//! Concrete implementations of aggregate functions
2
3use anyhow::Result;
4
5use super::{AggregateFunction, AggregateState, AvgState, MinMaxState, SumState};
6use crate::data::datatable::DataValue;
7
8/// COUNT(*) - counts all rows including nulls
9pub struct CountStarFunction;
10
11impl AggregateFunction for CountStarFunction {
12    fn name(&self) -> &str {
13        "COUNT_STAR"
14    }
15
16    fn init(&self) -> AggregateState {
17        AggregateState::Count(0)
18    }
19
20    fn accumulate(&self, state: &mut AggregateState, _value: &DataValue) -> Result<()> {
21        if let AggregateState::Count(ref mut count) = state {
22            *count += 1;
23        }
24        Ok(())
25    }
26
27    fn finalize(&self, state: AggregateState) -> DataValue {
28        if let AggregateState::Count(count) = state {
29            DataValue::Integer(count)
30        } else {
31            DataValue::Null
32        }
33    }
34}
35
36/// COUNT(column) - counts non-null values
37pub struct CountFunction;
38
39impl AggregateFunction for CountFunction {
40    fn name(&self) -> &str {
41        "COUNT"
42    }
43
44    fn init(&self) -> AggregateState {
45        AggregateState::Count(0)
46    }
47
48    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
49        if let AggregateState::Count(ref mut count) = state {
50            if !matches!(value, DataValue::Null) {
51                *count += 1;
52            }
53        }
54        Ok(())
55    }
56
57    fn finalize(&self, state: AggregateState) -> DataValue {
58        if let AggregateState::Count(count) = state {
59            DataValue::Integer(count)
60        } else {
61            DataValue::Null
62        }
63    }
64}
65
66/// SUM(column) - sums numeric values
67pub struct SumFunction;
68
69impl AggregateFunction for SumFunction {
70    fn name(&self) -> &str {
71        "SUM"
72    }
73
74    fn init(&self) -> AggregateState {
75        AggregateState::Sum(SumState::new())
76    }
77
78    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
79        if let AggregateState::Sum(ref mut sum_state) = state {
80            sum_state.add(value)?;
81        }
82        Ok(())
83    }
84
85    fn finalize(&self, state: AggregateState) -> DataValue {
86        if let AggregateState::Sum(sum_state) = state {
87            sum_state.finalize()
88        } else {
89            DataValue::Null
90        }
91    }
92
93    fn requires_numeric(&self) -> bool {
94        true
95    }
96}
97
98/// AVG(column) - averages numeric values
99pub struct AvgFunction;
100
101impl AggregateFunction for AvgFunction {
102    fn name(&self) -> &str {
103        "AVG"
104    }
105
106    fn init(&self) -> AggregateState {
107        AggregateState::Avg(AvgState::new())
108    }
109
110    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
111        if let AggregateState::Avg(ref mut avg_state) = state {
112            avg_state.add(value)?;
113        }
114        Ok(())
115    }
116
117    fn finalize(&self, state: AggregateState) -> DataValue {
118        if let AggregateState::Avg(avg_state) = state {
119            avg_state.finalize()
120        } else {
121            DataValue::Null
122        }
123    }
124
125    fn requires_numeric(&self) -> bool {
126        true
127    }
128}
129
130/// MIN(column) - finds minimum value
131pub struct MinFunction;
132
133impl AggregateFunction for MinFunction {
134    fn name(&self) -> &str {
135        "MIN"
136    }
137
138    fn init(&self) -> AggregateState {
139        AggregateState::MinMax(MinMaxState::new(true))
140    }
141
142    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
143        if let AggregateState::MinMax(ref mut minmax_state) = state {
144            minmax_state.add(value)?;
145        }
146        Ok(())
147    }
148
149    fn finalize(&self, state: AggregateState) -> DataValue {
150        if let AggregateState::MinMax(minmax_state) = state {
151            minmax_state.finalize()
152        } else {
153            DataValue::Null
154        }
155    }
156}
157
158/// MAX(column) - finds maximum value
159pub struct MaxFunction;
160
161impl AggregateFunction for MaxFunction {
162    fn name(&self) -> &str {
163        "MAX"
164    }
165
166    fn init(&self) -> AggregateState {
167        AggregateState::MinMax(MinMaxState::new(false))
168    }
169
170    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
171        if let AggregateState::MinMax(ref mut minmax_state) = state {
172            minmax_state.add(value)?;
173        }
174        Ok(())
175    }
176
177    fn finalize(&self, state: AggregateState) -> DataValue {
178        if let AggregateState::MinMax(minmax_state) = state {
179            minmax_state.finalize()
180        } else {
181            DataValue::Null
182        }
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `values` through `func` in order and returns the finalized result.
    ///
    /// Panics if any `accumulate` call fails, which none of these inputs should.
    fn run<F: AggregateFunction>(func: &F, values: &[DataValue]) -> DataValue {
        let mut state = func.init();
        for value in values {
            func.accumulate(&mut state, value)
                .expect("accumulate should succeed for these inputs");
        }
        func.finalize(state)
    }

    /// Shorthand for building a string value.
    fn text(s: &str) -> DataValue {
        DataValue::String(s.to_string())
    }

    #[test]
    fn test_count_star() {
        // COUNT(*) counts everything including nulls
        let values = [DataValue::Integer(5), DataValue::Null, text("test")];
        assert_eq!(run(&CountStarFunction, &values), DataValue::Integer(3));
    }

    #[test]
    fn test_count_star_empty() {
        // No rows at all still yields a count of zero, not NULL.
        assert_eq!(run(&CountStarFunction, &[]), DataValue::Integer(0));
    }

    #[test]
    fn test_count_column() {
        // COUNT(column) skips nulls
        let values = [
            DataValue::Integer(5),
            DataValue::Null,
            text("test"),
            DataValue::Null,
        ];
        assert_eq!(run(&CountFunction, &values), DataValue::Integer(2));
    }

    #[test]
    fn test_count_column_all_nulls() {
        let values = [DataValue::Null, DataValue::Null];
        assert_eq!(run(&CountFunction, &values), DataValue::Integer(0));
    }

    #[test]
    fn test_sum_integers() {
        let values = [
            DataValue::Integer(10),
            DataValue::Integer(20),
            DataValue::Integer(30),
            DataValue::Null, // Ignored
        ];
        assert_eq!(run(&SumFunction, &values), DataValue::Integer(60));
    }

    #[test]
    fn test_sum_mixed() {
        let values = [
            DataValue::Integer(10),
            DataValue::Float(20.5), // Converts to float
            DataValue::Integer(30),
        ];
        match run(&SumFunction, &values) {
            DataValue::Float(f) => assert!((f - 60.5).abs() < 0.001),
            _ => panic!("Expected Float result"),
        }
    }

    #[test]
    fn test_avg() {
        let values = [
            DataValue::Integer(10),
            DataValue::Integer(20),
            DataValue::Integer(30),
            DataValue::Null, // Ignored
        ];
        match run(&AvgFunction, &values) {
            DataValue::Float(f) => assert!((f - 20.0).abs() < 0.001),
            _ => panic!("Expected Float result"),
        }
    }

    #[test]
    fn test_min() {
        let values = [
            DataValue::Integer(30),
            DataValue::Integer(10),
            DataValue::Integer(20),
            DataValue::Null, // Ignored
        ];
        assert_eq!(run(&MinFunction, &values), DataValue::Integer(10));
    }

    #[test]
    fn test_max() {
        let values = [
            DataValue::Integer(10),
            DataValue::Integer(30),
            DataValue::Integer(20),
            DataValue::Null, // Ignored
        ];
        assert_eq!(run(&MaxFunction, &values), DataValue::Integer(30));
    }

    #[test]
    fn test_min_strings() {
        let values = [text("banana"), text("apple"), text("zebra")];
        assert_eq!(run(&MinFunction, &values), text("apple"));
    }

    #[test]
    fn test_max_strings() {
        let values = [text("apple"), text("zebra"), text("banana")];
        assert_eq!(run(&MaxFunction, &values), text("zebra"));
    }

    #[test]
    fn test_names() {
        assert_eq!(CountStarFunction.name(), "COUNT_STAR");
        assert_eq!(CountFunction.name(), "COUNT");
        assert_eq!(SumFunction.name(), "SUM");
        assert_eq!(AvgFunction.name(), "AVG");
        assert_eq!(MinFunction.name(), "MIN");
        assert_eq!(MaxFunction.name(), "MAX");
    }

    #[test]
    fn test_requires_numeric() {
        assert!(SumFunction.requires_numeric());
        assert!(AvgFunction.requires_numeric());
    }

    #[test]
    fn test_mismatched_state_is_tolerated() {
        // Documents the current defensive behaviour: a state of the wrong kind is
        // ignored by `accumulate` and finalizes to NULL.
        let mut count_state = CountFunction.init();
        SumFunction
            .accumulate(&mut count_state, &DataValue::Integer(7))
            .expect("mismatched state is a no-op, not an error");
        assert_eq!(CountFunction.finalize(count_state), DataValue::Integer(0));

        assert_eq!(CountFunction.finalize(SumFunction.init()), DataValue::Null);
    }
}