sql_cli/sql/aggregates/
functions.rs1use anyhow::Result;
4
5use super::{AggregateFunction, AggregateState, AvgState, MinMaxState, SumState};
6use crate::data::datatable::DataValue;
7
8pub struct CountStarFunction;
10
11impl AggregateFunction for CountStarFunction {
12 fn name(&self) -> &str {
13 "COUNT_STAR"
14 }
15
16 fn init(&self) -> AggregateState {
17 AggregateState::Count(0)
18 }
19
20 fn accumulate(&self, state: &mut AggregateState, _value: &DataValue) -> Result<()> {
21 if let AggregateState::Count(ref mut count) = state {
22 *count += 1;
23 }
24 Ok(())
25 }
26
27 fn finalize(&self, state: AggregateState) -> DataValue {
28 if let AggregateState::Count(count) = state {
29 DataValue::Integer(count)
30 } else {
31 DataValue::Null
32 }
33 }
34}
35
36pub struct CountFunction;
38
39impl AggregateFunction for CountFunction {
40 fn name(&self) -> &str {
41 "COUNT"
42 }
43
44 fn init(&self) -> AggregateState {
45 AggregateState::Count(0)
46 }
47
48 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
49 if let AggregateState::Count(ref mut count) = state {
50 if !matches!(value, DataValue::Null) {
51 *count += 1;
52 }
53 }
54 Ok(())
55 }
56
57 fn finalize(&self, state: AggregateState) -> DataValue {
58 if let AggregateState::Count(count) = state {
59 DataValue::Integer(count)
60 } else {
61 DataValue::Null
62 }
63 }
64}
65
66pub struct SumFunction;
68
69impl AggregateFunction for SumFunction {
70 fn name(&self) -> &str {
71 "SUM"
72 }
73
74 fn init(&self) -> AggregateState {
75 AggregateState::Sum(SumState::new())
76 }
77
78 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
79 if let AggregateState::Sum(ref mut sum_state) = state {
80 sum_state.add(value)?;
81 }
82 Ok(())
83 }
84
85 fn finalize(&self, state: AggregateState) -> DataValue {
86 if let AggregateState::Sum(sum_state) = state {
87 sum_state.finalize()
88 } else {
89 DataValue::Null
90 }
91 }
92
93 fn requires_numeric(&self) -> bool {
94 true
95 }
96}
97
98pub struct AvgFunction;
100
101impl AggregateFunction for AvgFunction {
102 fn name(&self) -> &str {
103 "AVG"
104 }
105
106 fn init(&self) -> AggregateState {
107 AggregateState::Avg(AvgState::new())
108 }
109
110 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
111 if let AggregateState::Avg(ref mut avg_state) = state {
112 avg_state.add(value)?;
113 }
114 Ok(())
115 }
116
117 fn finalize(&self, state: AggregateState) -> DataValue {
118 if let AggregateState::Avg(avg_state) = state {
119 avg_state.finalize()
120 } else {
121 DataValue::Null
122 }
123 }
124
125 fn requires_numeric(&self) -> bool {
126 true
127 }
128}
129
130pub struct MinFunction;
132
133impl AggregateFunction for MinFunction {
134 fn name(&self) -> &str {
135 "MIN"
136 }
137
138 fn init(&self) -> AggregateState {
139 AggregateState::MinMax(MinMaxState::new(true))
140 }
141
142 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
143 if let AggregateState::MinMax(ref mut minmax_state) = state {
144 minmax_state.add(value)?;
145 }
146 Ok(())
147 }
148
149 fn finalize(&self, state: AggregateState) -> DataValue {
150 if let AggregateState::MinMax(minmax_state) = state {
151 minmax_state.finalize()
152 } else {
153 DataValue::Null
154 }
155 }
156}
157
158pub struct MaxFunction;
160
161impl AggregateFunction for MaxFunction {
162 fn name(&self) -> &str {
163 "MAX"
164 }
165
166 fn init(&self) -> AggregateState {
167 AggregateState::MinMax(MinMaxState::new(false))
168 }
169
170 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
171 if let AggregateState::MinMax(ref mut minmax_state) = state {
172 minmax_state.add(value)?;
173 }
174 Ok(())
175 }
176
177 fn finalize(&self, state: AggregateState) -> DataValue {
178 if let AggregateState::MinMax(minmax_state) = state {
179 minmax_state.finalize()
180 } else {
181 DataValue::Null
182 }
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count_star() {
        let func = CountStarFunction;
        let mut state = func.init();

        // COUNT(*) counts every row, NULLs included.
        func.accumulate(&mut state, &DataValue::Integer(5)).unwrap();
        func.accumulate(&mut state, &DataValue::Null).unwrap();
        func.accumulate(&mut state, &DataValue::String("test".to_string())).unwrap();

        let result = func.finalize(state);
        assert_eq!(result, DataValue::Integer(3));
    }

    #[test]
    fn test_count_star_empty() {
        // With no rows at all, COUNT(*) must still report 0 rather than NULL.
        let func = CountStarFunction;
        let state = func.init();
        assert_eq!(func.finalize(state), DataValue::Integer(0));
    }

    #[test]
    fn test_count_column() {
        let func = CountFunction;
        let mut state = func.init();

        // COUNT(expr) skips NULLs.
        func.accumulate(&mut state, &DataValue::Integer(5)).unwrap();
        func.accumulate(&mut state, &DataValue::Null).unwrap();
        func.accumulate(&mut state, &DataValue::String("test".to_string())).unwrap();
        func.accumulate(&mut state, &DataValue::Null).unwrap();

        let result = func.finalize(state);
        assert_eq!(result, DataValue::Integer(2));
    }

    #[test]
    fn test_count_column_all_nulls() {
        let func = CountFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::Null).unwrap();
        func.accumulate(&mut state, &DataValue::Null).unwrap();

        assert_eq!(func.finalize(state), DataValue::Integer(0));
    }

    #[test]
    fn test_sum_integers() {
        let func = SumFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::Integer(10)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(20)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(30)).unwrap();
        func.accumulate(&mut state, &DataValue::Null).unwrap();

        let result = func.finalize(state);
        assert_eq!(result, DataValue::Integer(60));
    }

    #[test]
    fn test_sum_mixed() {
        let func = SumFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::Integer(10)).unwrap();
        func.accumulate(&mut state, &DataValue::Float(20.5)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(30)).unwrap();

        let result = func.finalize(state);
        match result {
            DataValue::Float(f) => assert!((f - 60.5).abs() < 0.001),
            _ => panic!("Expected Float result"),
        }
    }

    #[test]
    fn test_avg() {
        let func = AvgFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::Integer(10)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(20)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(30)).unwrap();
        func.accumulate(&mut state, &DataValue::Null).unwrap();

        let result = func.finalize(state);
        match result {
            DataValue::Float(f) => assert!((f - 20.0).abs() < 0.001),
            _ => panic!("Expected Float result"),
        }
    }

    #[test]
    fn test_min() {
        let func = MinFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::Integer(30)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(10)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(20)).unwrap();
        func.accumulate(&mut state, &DataValue::Null).unwrap();

        let result = func.finalize(state);
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_max() {
        let func = MaxFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::Integer(10)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(30)).unwrap();
        func.accumulate(&mut state, &DataValue::Integer(20)).unwrap();
        func.accumulate(&mut state, &DataValue::Null).unwrap();

        let result = func.finalize(state);
        assert_eq!(result, DataValue::Integer(30));
    }

    #[test]
    fn test_max_strings() {
        let func = MaxFunction;
        let mut state = func.init();

        func.accumulate(&mut state, &DataValue::String("apple".to_string())).unwrap();
        func.accumulate(&mut state, &DataValue::String("zebra".to_string())).unwrap();
        func.accumulate(&mut state, &DataValue::String("banana".to_string())).unwrap();

        let result = func.finalize(state);
        assert_eq!(result, DataValue::String("zebra".to_string()));
    }

    #[test]
    fn test_names_and_numeric_requirements() {
        // Names are the identifiers callers see; pin them so a rename is deliberate.
        assert_eq!(CountStarFunction.name(), "COUNT_STAR");
        assert_eq!(CountFunction.name(), "COUNT");
        assert_eq!(SumFunction.name(), "SUM");
        assert_eq!(AvgFunction.name(), "AVG");
        assert_eq!(MinFunction.name(), "MIN");
        assert_eq!(MaxFunction.name(), "MAX");

        assert!(SumFunction.requires_numeric());
        assert!(AvgFunction.requires_numeric());
    }

    #[test]
    fn test_finalize_with_foreign_state_is_null() {
        // Each finalize falls back to NULL when handed another aggregate's state.
        assert_eq!(CountStarFunction.finalize(SumFunction.init()), DataValue::Null);
        assert_eq!(CountFunction.finalize(AvgFunction.init()), DataValue::Null);
        assert_eq!(SumFunction.finalize(CountFunction.init()), DataValue::Null);
        assert_eq!(AvgFunction.finalize(MinFunction.init()), DataValue::Null);
        assert_eq!(MinFunction.finalize(CountFunction.init()), DataValue::Null);
        assert_eq!(MaxFunction.finalize(SumFunction.init()), DataValue::Null);
    }
}