1use anyhow::Result;
4
5use super::{AggregateFunction, AggregateState, AvgState, MinMaxState, SumState, VarianceState};
6use crate::data::datatable::DataValue;
7
8pub struct CountStarFunction;
10
11impl AggregateFunction for CountStarFunction {
12 fn name(&self) -> &'static str {
13 "COUNT_STAR"
14 }
15
16 fn init(&self) -> AggregateState {
17 AggregateState::Count(0)
18 }
19
20 fn accumulate(&self, state: &mut AggregateState, _value: &DataValue) -> Result<()> {
21 if let AggregateState::Count(ref mut count) = state {
22 *count += 1;
23 }
24 Ok(())
25 }
26
27 fn finalize(&self, state: AggregateState) -> DataValue {
28 if let AggregateState::Count(count) = state {
29 DataValue::Integer(count)
30 } else {
31 DataValue::Null
32 }
33 }
34}
35
36pub struct CountFunction;
38
39impl AggregateFunction for CountFunction {
40 fn name(&self) -> &'static str {
41 "COUNT"
42 }
43
44 fn init(&self) -> AggregateState {
45 AggregateState::Count(0)
46 }
47
48 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
49 if let AggregateState::Count(ref mut count) = state {
50 if !matches!(value, DataValue::Null) {
51 *count += 1;
52 }
53 }
54 Ok(())
55 }
56
57 fn finalize(&self, state: AggregateState) -> DataValue {
58 if let AggregateState::Count(count) = state {
59 DataValue::Integer(count)
60 } else {
61 DataValue::Null
62 }
63 }
64}
65
66pub struct SumFunction;
68
69impl AggregateFunction for SumFunction {
70 fn name(&self) -> &'static str {
71 "SUM"
72 }
73
74 fn init(&self) -> AggregateState {
75 AggregateState::Sum(SumState::new())
76 }
77
78 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
79 if let AggregateState::Sum(ref mut sum_state) = state {
80 sum_state.add(value)?;
81 }
82 Ok(())
83 }
84
85 fn finalize(&self, state: AggregateState) -> DataValue {
86 if let AggregateState::Sum(sum_state) = state {
87 sum_state.finalize()
88 } else {
89 DataValue::Null
90 }
91 }
92
93 fn requires_numeric(&self) -> bool {
94 true
95 }
96}
97
98pub struct AvgFunction;
100
101impl AggregateFunction for AvgFunction {
102 fn name(&self) -> &'static str {
103 "AVG"
104 }
105
106 fn init(&self) -> AggregateState {
107 AggregateState::Avg(AvgState::new())
108 }
109
110 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
111 if let AggregateState::Avg(ref mut avg_state) = state {
112 avg_state.add(value)?;
113 }
114 Ok(())
115 }
116
117 fn finalize(&self, state: AggregateState) -> DataValue {
118 if let AggregateState::Avg(avg_state) = state {
119 avg_state.finalize()
120 } else {
121 DataValue::Null
122 }
123 }
124
125 fn requires_numeric(&self) -> bool {
126 true
127 }
128}
129
130pub struct MinFunction;
132
133impl AggregateFunction for MinFunction {
134 fn name(&self) -> &'static str {
135 "MIN"
136 }
137
138 fn init(&self) -> AggregateState {
139 AggregateState::MinMax(MinMaxState::new(true))
140 }
141
142 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
143 if let AggregateState::MinMax(ref mut minmax_state) = state {
144 minmax_state.add(value)?;
145 }
146 Ok(())
147 }
148
149 fn finalize(&self, state: AggregateState) -> DataValue {
150 if let AggregateState::MinMax(minmax_state) = state {
151 minmax_state.finalize()
152 } else {
153 DataValue::Null
154 }
155 }
156}
157
158pub struct MaxFunction;
160
161impl AggregateFunction for MaxFunction {
162 fn name(&self) -> &'static str {
163 "MAX"
164 }
165
166 fn init(&self) -> AggregateState {
167 AggregateState::MinMax(MinMaxState::new(false))
168 }
169
170 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
171 if let AggregateState::MinMax(ref mut minmax_state) = state {
172 minmax_state.add(value)?;
173 }
174 Ok(())
175 }
176
177 fn finalize(&self, state: AggregateState) -> DataValue {
178 if let AggregateState::MinMax(minmax_state) = state {
179 minmax_state.finalize()
180 } else {
181 DataValue::Null
182 }
183 }
184}
185
186pub struct VarianceFunction;
188
189impl AggregateFunction for VarianceFunction {
190 fn name(&self) -> &'static str {
191 "VARIANCE"
192 }
193
194 fn init(&self) -> AggregateState {
195 AggregateState::Variance(VarianceState::new())
196 }
197
198 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
199 if let AggregateState::Variance(ref mut var_state) = state {
200 var_state.add(value)?;
201 }
202 Ok(())
203 }
204
205 fn finalize(&self, state: AggregateState) -> DataValue {
206 if let AggregateState::Variance(var_state) = state {
207 var_state.finalize_variance()
208 } else {
209 DataValue::Null
210 }
211 }
212
213 fn requires_numeric(&self) -> bool {
214 true
215 }
216}
217
218pub struct StdDevFunction;
220
221impl AggregateFunction for StdDevFunction {
222 fn name(&self) -> &'static str {
223 "STDDEV"
224 }
225
226 fn init(&self) -> AggregateState {
227 AggregateState::Variance(VarianceState::new())
228 }
229
230 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()> {
231 if let AggregateState::Variance(ref mut var_state) = state {
232 var_state.add(value)?;
233 }
234 Ok(())
235 }
236
237 fn finalize(&self, state: AggregateState) -> DataValue {
238 if let AggregateState::Variance(var_state) = state {
239 var_state.finalize_stddev()
240 } else {
241 DataValue::Null
242 }
243 }
244
245 fn requires_numeric(&self) -> bool {
246 true
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `func` over `inputs` in order, starting from a fresh state, and
    /// returns the finalized value. Any accumulate error fails the test.
    fn aggregate<F: AggregateFunction>(func: F, inputs: &[DataValue]) -> DataValue {
        let mut state = func.init();
        for input in inputs {
            func.accumulate(&mut state, input).unwrap();
        }
        func.finalize(state)
    }

    #[test]
    fn test_count_star() {
        let result = aggregate(
            CountStarFunction,
            &[
                DataValue::Integer(5),
                DataValue::Null,
                DataValue::String("test".to_string()),
            ],
        );
        assert_eq!(result, DataValue::Integer(3));
    }

    #[test]
    fn test_count_column() {
        let result = aggregate(
            CountFunction,
            &[
                DataValue::Integer(5),
                DataValue::Null,
                DataValue::String("test".to_string()),
                DataValue::Null,
            ],
        );
        assert_eq!(result, DataValue::Integer(2));
    }

    #[test]
    fn test_sum_integers() {
        let result = aggregate(
            SumFunction,
            &[
                DataValue::Integer(10),
                DataValue::Integer(20),
                DataValue::Integer(30),
                DataValue::Null,
            ],
        );
        assert_eq!(result, DataValue::Integer(60));
    }

    #[test]
    fn test_sum_mixed() {
        let result = aggregate(
            SumFunction,
            &[
                DataValue::Integer(10),
                DataValue::Float(20.5),
                DataValue::Integer(30),
            ],
        );
        if let DataValue::Float(f) = result {
            assert!((f - 60.5).abs() < 0.001);
        } else {
            panic!("Expected Float result");
        }
    }

    #[test]
    fn test_avg() {
        let result = aggregate(
            AvgFunction,
            &[
                DataValue::Integer(10),
                DataValue::Integer(20),
                DataValue::Integer(30),
                DataValue::Null,
            ],
        );
        if let DataValue::Float(f) = result {
            assert!((f - 20.0).abs() < 0.001);
        } else {
            panic!("Expected Float result");
        }
    }

    #[test]
    fn test_min() {
        let result = aggregate(
            MinFunction,
            &[
                DataValue::Integer(30),
                DataValue::Integer(10),
                DataValue::Integer(20),
                DataValue::Null,
            ],
        );
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_max() {
        let result = aggregate(
            MaxFunction,
            &[
                DataValue::Integer(10),
                DataValue::Integer(30),
                DataValue::Integer(20),
                DataValue::Null,
            ],
        );
        assert_eq!(result, DataValue::Integer(30));
    }

    #[test]
    fn test_max_strings() {
        let result = aggregate(
            MaxFunction,
            &[
                DataValue::String("apple".to_string()),
                DataValue::String("zebra".to_string()),
                DataValue::String("banana".to_string()),
            ],
        );
        assert_eq!(result, DataValue::String("zebra".to_string()));
    }

    #[test]
    fn test_variance() {
        let result = aggregate(
            VarianceFunction,
            &[
                DataValue::Integer(2),
                DataValue::Integer(4),
                DataValue::Integer(6),
                DataValue::Integer(8),
                DataValue::Integer(10),
            ],
        );
        // Population variance of 2, 4, 6, 8, 10: sum of squared deviations 40 / 5.
        if let DataValue::Float(f) = result {
            assert!((f - 8.0).abs() < 0.001);
        } else {
            panic!("Expected Float result");
        }
    }

    #[test]
    fn test_stddev() {
        let result = aggregate(
            StdDevFunction,
            &[
                DataValue::Integer(2),
                DataValue::Integer(4),
                DataValue::Integer(6),
                DataValue::Integer(8),
                DataValue::Integer(10),
            ],
        );
        // sqrt(8), the population standard deviation of the same inputs.
        if let DataValue::Float(f) = result {
            assert!((f - 2.8284271247461903).abs() < 0.001);
        } else {
            panic!("Expected Float result");
        }
    }

    #[test]
    fn test_variance_with_nulls() {
        let result = aggregate(
            VarianceFunction,
            &[
                DataValue::Integer(5),
                DataValue::Null,
                DataValue::Integer(10),
                DataValue::Integer(15),
            ],
        );
        // NULL is skipped: population variance of 5, 10, 15 is 50 / 3.
        if let DataValue::Float(f) = result {
            assert!((f - 16.666666666666668).abs() < 0.001);
        } else {
            panic!("Expected Float result");
        }
    }
}