sql_cli/sql/aggregate_functions/
mod.rs1use anyhow::{anyhow, Result};
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use crate::data::datatable::DataValue;
10
11pub trait AggregateState: Send + Sync {
14 fn accumulate(&mut self, value: &DataValue) -> Result<()>;
16
17 fn finalize(self: Box<Self>) -> DataValue;
19
20 fn clone_box(&self) -> Box<dyn AggregateState>;
22
23 fn reset(&mut self);
25}
26
27pub trait AggregateFunction: Send + Sync {
30 fn name(&self) -> &str;
32
33 fn description(&self) -> &str;
35
36 fn create_state(&self) -> Box<dyn AggregateState>;
38
39 fn supports_distinct(&self) -> bool {
41 true }
43
44 fn set_parameters(&self, _params: &[DataValue]) -> Result<Box<dyn AggregateFunction>> {
46 Ok(Box::new(DummyClone(self.name().to_string())))
48 }
49}
50
51struct DummyClone(String);
53impl AggregateFunction for DummyClone {
54 fn name(&self) -> &str {
55 &self.0
56 }
57 fn description(&self) -> &str {
58 ""
59 }
60 fn create_state(&self) -> Box<dyn AggregateState> {
61 panic!("DummyClone should not be used")
62 }
63}
64
65pub struct AggregateFunctionRegistry {
67 functions: HashMap<String, Arc<Box<dyn AggregateFunction>>>,
68}
69
70impl AggregateFunctionRegistry {
71 pub fn new() -> Self {
72 let mut registry = Self {
73 functions: HashMap::new(),
74 };
75 registry.register_builtin_functions();
76 registry
77 }
78
79 pub fn register(&mut self, function: Box<dyn AggregateFunction>) {
81 let name = function.name().to_uppercase();
82 self.functions.insert(name, Arc::new(function));
83 }
84
85 pub fn get(&self, name: &str) -> Option<Arc<Box<dyn AggregateFunction>>> {
87 self.functions.get(&name.to_uppercase()).cloned()
88 }
89
90 pub fn contains(&self, name: &str) -> bool {
92 self.functions.contains_key(&name.to_uppercase())
93 }
94
95 pub fn list_functions(&self) -> Vec<String> {
97 self.functions.keys().cloned().collect()
98 }
99
100 fn register_builtin_functions(&mut self) {
102 self.register(Box::new(CountFunction));
104 self.register(Box::new(SumFunction));
105 self.register(Box::new(AvgFunction));
106 self.register(Box::new(MinFunction));
107 self.register(Box::new(MaxFunction));
108
109 self.register(Box::new(StringAggFunction::new()));
111
112 }
116}
117
118struct CountFunction;
121
122impl AggregateFunction for CountFunction {
123 fn name(&self) -> &str {
124 "COUNT"
125 }
126
127 fn description(&self) -> &str {
128 "Count the number of non-null values or rows"
129 }
130
131 fn create_state(&self) -> Box<dyn AggregateState> {
132 Box::new(CountState { count: 0 })
133 }
134}
135
136struct CountState {
137 count: i64,
138}
139
140impl AggregateState for CountState {
141 fn accumulate(&mut self, value: &DataValue) -> Result<()> {
142 if !matches!(value, DataValue::Null) {
144 self.count += 1;
145 }
146 Ok(())
147 }
148
149 fn finalize(self: Box<Self>) -> DataValue {
150 DataValue::Integer(self.count)
151 }
152
153 fn clone_box(&self) -> Box<dyn AggregateState> {
154 Box::new(CountState { count: self.count })
155 }
156
157 fn reset(&mut self) {
158 self.count = 0;
159 }
160}
161
162struct SumFunction;
165
166impl AggregateFunction for SumFunction {
167 fn name(&self) -> &str {
168 "SUM"
169 }
170
171 fn description(&self) -> &str {
172 "Calculate the sum of values"
173 }
174
175 fn create_state(&self) -> Box<dyn AggregateState> {
176 Box::new(SumState {
177 int_sum: None,
178 float_sum: None,
179 has_values: false,
180 })
181 }
182}
183
184struct SumState {
185 int_sum: Option<i64>,
186 float_sum: Option<f64>,
187 has_values: bool,
188}
189
190impl AggregateState for SumState {
191 fn accumulate(&mut self, value: &DataValue) -> Result<()> {
192 match value {
193 DataValue::Null => Ok(()), DataValue::Integer(n) => {
195 self.has_values = true;
196 if let Some(ref mut sum) = self.int_sum {
197 *sum = sum.saturating_add(*n);
198 } else if let Some(ref mut fsum) = self.float_sum {
199 *fsum += *n as f64;
200 } else {
201 self.int_sum = Some(*n);
202 }
203 Ok(())
204 }
205 DataValue::Float(f) => {
206 self.has_values = true;
207 if let Some(isum) = self.int_sum.take() {
209 self.float_sum = Some(isum as f64 + f);
210 } else if let Some(ref mut fsum) = self.float_sum {
211 *fsum += f;
212 } else {
213 self.float_sum = Some(*f);
214 }
215 Ok(())
216 }
217 _ => Err(anyhow!("Cannot sum non-numeric value")),
218 }
219 }
220
221 fn finalize(self: Box<Self>) -> DataValue {
222 if !self.has_values {
223 return DataValue::Null;
224 }
225
226 if let Some(fsum) = self.float_sum {
227 DataValue::Float(fsum)
228 } else if let Some(isum) = self.int_sum {
229 DataValue::Integer(isum)
230 } else {
231 DataValue::Null
232 }
233 }
234
235 fn clone_box(&self) -> Box<dyn AggregateState> {
236 Box::new(SumState {
237 int_sum: self.int_sum,
238 float_sum: self.float_sum,
239 has_values: self.has_values,
240 })
241 }
242
243 fn reset(&mut self) {
244 self.int_sum = None;
245 self.float_sum = None;
246 self.has_values = false;
247 }
248}
249
250struct AvgFunction;
253
254impl AggregateFunction for AvgFunction {
255 fn name(&self) -> &str {
256 "AVG"
257 }
258
259 fn description(&self) -> &str {
260 "Calculate the average of values"
261 }
262
263 fn create_state(&self) -> Box<dyn AggregateState> {
264 Box::new(AvgState {
265 sum: SumState {
266 int_sum: None,
267 float_sum: None,
268 has_values: false,
269 },
270 count: 0,
271 })
272 }
273}
274
275struct AvgState {
276 sum: SumState,
277 count: i64,
278}
279
280impl AggregateState for AvgState {
281 fn accumulate(&mut self, value: &DataValue) -> Result<()> {
282 if !matches!(value, DataValue::Null) {
283 self.sum.accumulate(value)?;
284 self.count += 1;
285 }
286 Ok(())
287 }
288
289 fn finalize(self: Box<Self>) -> DataValue {
290 if self.count == 0 {
291 return DataValue::Null;
292 }
293
294 let sum = Box::new(self.sum).finalize();
295 match sum {
296 DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
297 DataValue::Float(f) => DataValue::Float(f / self.count as f64),
298 _ => DataValue::Null,
299 }
300 }
301
302 fn clone_box(&self) -> Box<dyn AggregateState> {
303 Box::new(AvgState {
304 sum: SumState {
305 int_sum: self.sum.int_sum,
306 float_sum: self.sum.float_sum,
307 has_values: self.sum.has_values,
308 },
309 count: self.count,
310 })
311 }
312
313 fn reset(&mut self) {
314 self.sum.reset();
315 self.count = 0;
316 }
317}
318
319struct MinFunction;
322
323impl AggregateFunction for MinFunction {
324 fn name(&self) -> &str {
325 "MIN"
326 }
327
328 fn description(&self) -> &str {
329 "Find the minimum value"
330 }
331
332 fn create_state(&self) -> Box<dyn AggregateState> {
333 Box::new(MinMaxState {
334 is_min: true,
335 current: None,
336 })
337 }
338}
339
340struct MaxFunction;
343
344impl AggregateFunction for MaxFunction {
345 fn name(&self) -> &str {
346 "MAX"
347 }
348
349 fn description(&self) -> &str {
350 "Find the maximum value"
351 }
352
353 fn create_state(&self) -> Box<dyn AggregateState> {
354 Box::new(MinMaxState {
355 is_min: false,
356 current: None,
357 })
358 }
359}
360
361struct MinMaxState {
362 is_min: bool,
363 current: Option<DataValue>,
364}
365
366impl AggregateState for MinMaxState {
367 fn accumulate(&mut self, value: &DataValue) -> Result<()> {
368 if matches!(value, DataValue::Null) {
369 return Ok(());
370 }
371
372 match &self.current {
373 None => {
374 self.current = Some(value.clone());
375 }
376 Some(current) => {
377 let should_update = if self.is_min {
378 value < current
379 } else {
380 value > current
381 };
382
383 if should_update {
384 self.current = Some(value.clone());
385 }
386 }
387 }
388
389 Ok(())
390 }
391
392 fn finalize(self: Box<Self>) -> DataValue {
393 self.current.unwrap_or(DataValue::Null)
394 }
395
396 fn clone_box(&self) -> Box<dyn AggregateState> {
397 Box::new(MinMaxState {
398 is_min: self.is_min,
399 current: self.current.clone(),
400 })
401 }
402
403 fn reset(&mut self) {
404 self.current = None;
405 }
406}
407
/// The `STRING_AGG` aggregate, configured with the separator placed between
/// joined values.
struct StringAggFunction {
    separator: String,
}

impl StringAggFunction {
    /// Creates the function with the default separator, a single comma.
    fn new() -> Self {
        Self::with_separator(String::from(","))
    }

    /// Creates the function with a caller-supplied separator.
    fn with_separator(separator: String) -> Self {
        StringAggFunction { separator }
    }
}
425
426impl AggregateFunction for StringAggFunction {
427 fn name(&self) -> &str {
428 "STRING_AGG"
429 }
430
431 fn description(&self) -> &str {
432 "Concatenate strings with a separator"
433 }
434
435 fn create_state(&self) -> Box<dyn AggregateState> {
436 Box::new(StringAggState {
437 values: Vec::new(),
438 separator: self.separator.clone(),
439 })
440 }
441
442 fn set_parameters(&self, params: &[DataValue]) -> Result<Box<dyn AggregateFunction>> {
443 if params.is_empty() {
445 return Ok(Box::new(StringAggFunction::new()));
446 }
447
448 let separator = match ¶ms[0] {
449 DataValue::String(s) => s.clone(),
450 DataValue::InternedString(s) => s.to_string(),
451 _ => return Err(anyhow!("STRING_AGG separator must be a string")),
452 };
453
454 Ok(Box::new(StringAggFunction::with_separator(separator)))
455 }
456}
457
458struct StringAggState {
459 values: Vec<String>,
460 separator: String,
461}
462
463impl AggregateState for StringAggState {
464 fn accumulate(&mut self, value: &DataValue) -> Result<()> {
465 match value {
466 DataValue::Null => Ok(()), DataValue::String(s) => {
468 self.values.push(s.clone());
469 Ok(())
470 }
471 DataValue::InternedString(s) => {
472 self.values.push(s.to_string());
473 Ok(())
474 }
475 DataValue::Integer(n) => {
476 self.values.push(n.to_string());
477 Ok(())
478 }
479 DataValue::Float(f) => {
480 self.values.push(f.to_string());
481 Ok(())
482 }
483 DataValue::Boolean(b) => {
484 self.values.push(b.to_string());
485 Ok(())
486 }
487 DataValue::DateTime(dt) => {
488 self.values.push(dt.to_string());
489 Ok(())
490 }
491 }
492 }
493
494 fn finalize(self: Box<Self>) -> DataValue {
495 if self.values.is_empty() {
496 DataValue::Null
497 } else {
498 DataValue::String(self.values.join(&self.separator))
499 }
500 }
501
502 fn clone_box(&self) -> Box<dyn AggregateState> {
503 Box::new(StringAggState {
504 values: self.values.clone(),
505 separator: self.separator.clone(),
506 })
507 }
508
509 fn reset(&mut self) {
510 self.values.clear();
511 }
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// A new registry contains every built-in aggregate.
    #[test]
    fn test_registry_creation() {
        let registry = AggregateFunctionRegistry::new();
        let builtins = ["COUNT", "SUM", "AVG", "MIN", "MAX", "STRING_AGG"];
        for name in builtins.iter() {
            assert!(registry.contains(name));
        }
    }

    /// COUNT skips nulls and counts everything else.
    #[test]
    fn test_count_aggregate() {
        let mut state = CountFunction.create_state();

        let inputs = vec![
            DataValue::Integer(1),
            DataValue::Null,
            DataValue::Integer(3),
        ];
        for input in &inputs {
            state.accumulate(input).unwrap();
        }

        assert_eq!(state.finalize(), DataValue::Integer(2));
    }

    /// STRING_AGG joins values in arrival order with the given separator.
    #[test]
    fn test_string_agg() {
        let func = StringAggFunction::with_separator(", ".to_string());
        let mut state = func.create_state();

        for fruit in ["apple", "banana", "cherry"].iter() {
            state
                .accumulate(&DataValue::String(fruit.to_string()))
                .unwrap();
        }

        let expected = DataValue::String("apple, banana, cherry".to_string());
        assert_eq!(state.finalize(), expected);
    }
}