sql_cli/sql/aggregate_functions/
mod.rs1use anyhow::{anyhow, Result};
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use crate::data::datatable::DataValue;
10
11pub trait AggregateState: Send + Sync {
14 fn accumulate(&mut self, value: &DataValue) -> Result<()>;
16
17 fn finalize(self: Box<Self>) -> DataValue;
19
20 fn clone_box(&self) -> Box<dyn AggregateState>;
22
23 fn reset(&mut self);
25}
26
27pub trait AggregateFunction: Send + Sync {
30 fn name(&self) -> &str;
32
33 fn description(&self) -> &str;
35
36 fn create_state(&self) -> Box<dyn AggregateState>;
38
39 fn supports_distinct(&self) -> bool {
41 true }
43
44 fn set_parameters(&self, _params: &[DataValue]) -> Result<Box<dyn AggregateFunction>> {
46 Ok(Box::new(DummyClone(self.name().to_string())))
48 }
49}
50
51struct DummyClone(String);
53impl AggregateFunction for DummyClone {
54 fn name(&self) -> &str {
55 &self.0
56 }
57 fn description(&self) -> &str {
58 ""
59 }
60 fn create_state(&self) -> Box<dyn AggregateState> {
61 panic!("DummyClone should not be used")
62 }
63}
64
65pub struct AggregateFunctionRegistry {
67 functions: HashMap<String, Arc<Box<dyn AggregateFunction>>>,
68}
69
70impl AggregateFunctionRegistry {
71 pub fn new() -> Self {
72 let mut registry = Self {
73 functions: HashMap::new(),
74 };
75 registry.register_builtin_functions();
76 registry
77 }
78
79 pub fn register(&mut self, function: Box<dyn AggregateFunction>) {
81 let name = function.name().to_uppercase();
82 self.functions.insert(name, Arc::new(function));
83 }
84
85 pub fn get(&self, name: &str) -> Option<Arc<Box<dyn AggregateFunction>>> {
87 self.functions.get(&name.to_uppercase()).cloned()
88 }
89
90 pub fn contains(&self, name: &str) -> bool {
92 self.functions.contains_key(&name.to_uppercase())
93 }
94
95 pub fn list_functions(&self) -> Vec<String> {
97 self.functions.keys().cloned().collect()
98 }
99
100 fn register_builtin_functions(&mut self) {
102 self.register(Box::new(CountFunction));
104 self.register(Box::new(CountStarFunction));
105 self.register(Box::new(SumFunction));
106 self.register(Box::new(AvgFunction));
107 self.register(Box::new(MinFunction));
108 self.register(Box::new(MaxFunction));
109
110 self.register(Box::new(StringAggFunction::new()));
112
113 self.register(Box::new(MedianFunction));
115 self.register(Box::new(ModeFunction));
116 self.register(Box::new(StdDevFunction));
117 self.register(Box::new(StdDevPFunction));
118 self.register(Box::new(VarianceFunction));
119 self.register(Box::new(VariancePFunction));
120 self.register(Box::new(PercentileFunction));
121 }
122}
123
124struct CountFunction;
127
128impl AggregateFunction for CountFunction {
129 fn name(&self) -> &str {
130 "COUNT"
131 }
132
133 fn description(&self) -> &str {
134 "Count the number of non-null values"
135 }
136
137 fn create_state(&self) -> Box<dyn AggregateState> {
138 Box::new(CountState { count: 0 })
139 }
140}
141
142struct CountState {
143 count: i64,
144}
145
146impl AggregateState for CountState {
147 fn accumulate(&mut self, value: &DataValue) -> Result<()> {
148 if !matches!(value, DataValue::Null) {
150 self.count += 1;
151 }
152 Ok(())
153 }
154
155 fn finalize(self: Box<Self>) -> DataValue {
156 DataValue::Integer(self.count)
157 }
158
159 fn clone_box(&self) -> Box<dyn AggregateState> {
160 Box::new(CountState { count: self.count })
161 }
162
163 fn reset(&mut self) {
164 self.count = 0;
165 }
166}
167
168struct CountStarFunction;
170
171impl AggregateFunction for CountStarFunction {
172 fn name(&self) -> &str {
173 "COUNT_STAR"
174 }
175
176 fn description(&self) -> &str {
177 "Count all rows including nulls"
178 }
179
180 fn create_state(&self) -> Box<dyn AggregateState> {
181 Box::new(CountStarState { count: 0 })
182 }
183}
184
185struct CountStarState {
186 count: i64,
187}
188
189impl AggregateState for CountStarState {
190 fn accumulate(&mut self, _value: &DataValue) -> Result<()> {
191 self.count += 1;
193 Ok(())
194 }
195
196 fn finalize(self: Box<Self>) -> DataValue {
197 DataValue::Integer(self.count)
198 }
199
200 fn clone_box(&self) -> Box<dyn AggregateState> {
201 Box::new(CountStarState { count: self.count })
202 }
203
204 fn reset(&mut self) {
205 self.count = 0;
206 }
207}
208
209struct SumFunction;
212
213impl AggregateFunction for SumFunction {
214 fn name(&self) -> &str {
215 "SUM"
216 }
217
218 fn description(&self) -> &str {
219 "Calculate the sum of values"
220 }
221
222 fn create_state(&self) -> Box<dyn AggregateState> {
223 Box::new(SumState {
224 int_sum: None,
225 float_sum: None,
226 has_values: false,
227 })
228 }
229}
230
231struct SumState {
232 int_sum: Option<i64>,
233 float_sum: Option<f64>,
234 has_values: bool,
235}
236
237impl AggregateState for SumState {
238 fn accumulate(&mut self, value: &DataValue) -> Result<()> {
239 match value {
240 DataValue::Null => Ok(()), DataValue::Integer(n) => {
242 self.has_values = true;
243 if let Some(ref mut sum) = self.int_sum {
244 *sum = sum.saturating_add(*n);
245 } else if let Some(ref mut fsum) = self.float_sum {
246 *fsum += *n as f64;
247 } else {
248 self.int_sum = Some(*n);
249 }
250 Ok(())
251 }
252 DataValue::Float(f) => {
253 self.has_values = true;
254 if let Some(isum) = self.int_sum.take() {
256 self.float_sum = Some(isum as f64 + f);
257 } else if let Some(ref mut fsum) = self.float_sum {
258 *fsum += f;
259 } else {
260 self.float_sum = Some(*f);
261 }
262 Ok(())
263 }
264 _ => Err(anyhow!("Cannot sum non-numeric value")),
265 }
266 }
267
268 fn finalize(self: Box<Self>) -> DataValue {
269 if !self.has_values {
270 return DataValue::Null;
271 }
272
273 if let Some(fsum) = self.float_sum {
274 DataValue::Float(fsum)
275 } else if let Some(isum) = self.int_sum {
276 DataValue::Integer(isum)
277 } else {
278 DataValue::Null
279 }
280 }
281
282 fn clone_box(&self) -> Box<dyn AggregateState> {
283 Box::new(SumState {
284 int_sum: self.int_sum,
285 float_sum: self.float_sum,
286 has_values: self.has_values,
287 })
288 }
289
290 fn reset(&mut self) {
291 self.int_sum = None;
292 self.float_sum = None;
293 self.has_values = false;
294 }
295}
296
297struct AvgFunction;
300
301impl AggregateFunction for AvgFunction {
302 fn name(&self) -> &str {
303 "AVG"
304 }
305
306 fn description(&self) -> &str {
307 "Calculate the average of values"
308 }
309
310 fn create_state(&self) -> Box<dyn AggregateState> {
311 Box::new(AvgState {
312 sum: SumState {
313 int_sum: None,
314 float_sum: None,
315 has_values: false,
316 },
317 count: 0,
318 })
319 }
320}
321
322struct AvgState {
323 sum: SumState,
324 count: i64,
325}
326
327impl AggregateState for AvgState {
328 fn accumulate(&mut self, value: &DataValue) -> Result<()> {
329 if !matches!(value, DataValue::Null) {
330 self.sum.accumulate(value)?;
331 self.count += 1;
332 }
333 Ok(())
334 }
335
336 fn finalize(self: Box<Self>) -> DataValue {
337 if self.count == 0 {
338 return DataValue::Null;
339 }
340
341 let sum = Box::new(self.sum).finalize();
342 match sum {
343 DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
344 DataValue::Float(f) => DataValue::Float(f / self.count as f64),
345 _ => DataValue::Null,
346 }
347 }
348
349 fn clone_box(&self) -> Box<dyn AggregateState> {
350 Box::new(AvgState {
351 sum: SumState {
352 int_sum: self.sum.int_sum,
353 float_sum: self.sum.float_sum,
354 has_values: self.sum.has_values,
355 },
356 count: self.count,
357 })
358 }
359
360 fn reset(&mut self) {
361 self.sum.reset();
362 self.count = 0;
363 }
364}
365
366struct MinFunction;
369
370impl AggregateFunction for MinFunction {
371 fn name(&self) -> &str {
372 "MIN"
373 }
374
375 fn description(&self) -> &str {
376 "Find the minimum value"
377 }
378
379 fn create_state(&self) -> Box<dyn AggregateState> {
380 Box::new(MinMaxState {
381 is_min: true,
382 current: None,
383 })
384 }
385}
386
387struct MaxFunction;
390
391impl AggregateFunction for MaxFunction {
392 fn name(&self) -> &str {
393 "MAX"
394 }
395
396 fn description(&self) -> &str {
397 "Find the maximum value"
398 }
399
400 fn create_state(&self) -> Box<dyn AggregateState> {
401 Box::new(MinMaxState {
402 is_min: false,
403 current: None,
404 })
405 }
406}
407
408struct MinMaxState {
409 is_min: bool,
410 current: Option<DataValue>,
411}
412
413impl AggregateState for MinMaxState {
414 fn accumulate(&mut self, value: &DataValue) -> Result<()> {
415 if matches!(value, DataValue::Null) {
416 return Ok(());
417 }
418
419 match &self.current {
420 None => {
421 self.current = Some(value.clone());
422 }
423 Some(current) => {
424 let should_update = if self.is_min {
425 value < current
426 } else {
427 value > current
428 };
429
430 if should_update {
431 self.current = Some(value.clone());
432 }
433 }
434 }
435
436 Ok(())
437 }
438
439 fn finalize(self: Box<Self>) -> DataValue {
440 self.current.unwrap_or(DataValue::Null)
441 }
442
443 fn clone_box(&self) -> Box<dyn AggregateState> {
444 Box::new(MinMaxState {
445 is_min: self.is_min,
446 current: self.current.clone(),
447 })
448 }
449
450 fn reset(&mut self) {
451 self.current = None;
452 }
453}
454
/// `STRING_AGG(column [, separator])`: joins values into one string.
struct StringAggFunction {
    /// Text placed between consecutive values.
    separator: String,
}

impl StringAggFunction {
    /// Creates the function with the default separator, a single comma.
    fn new() -> Self {
        Self::with_separator(String::from(","))
    }

    /// Creates the function with a caller-supplied separator.
    fn with_separator(separator: String) -> Self {
        Self { separator }
    }
}
472
473impl AggregateFunction for StringAggFunction {
474 fn name(&self) -> &str {
475 "STRING_AGG"
476 }
477
478 fn description(&self) -> &str {
479 "Concatenate strings with a separator"
480 }
481
482 fn create_state(&self) -> Box<dyn AggregateState> {
483 Box::new(StringAggState {
484 values: Vec::new(),
485 separator: self.separator.clone(),
486 })
487 }
488
489 fn set_parameters(&self, params: &[DataValue]) -> Result<Box<dyn AggregateFunction>> {
490 if params.is_empty() {
492 return Ok(Box::new(StringAggFunction::new()));
493 }
494
495 let separator = match ¶ms[0] {
496 DataValue::String(s) => s.clone(),
497 DataValue::InternedString(s) => s.to_string(),
498 _ => return Err(anyhow!("STRING_AGG separator must be a string")),
499 };
500
501 Ok(Box::new(StringAggFunction::with_separator(separator)))
502 }
503}
504
505struct StringAggState {
506 values: Vec<String>,
507 separator: String,
508}
509
510impl AggregateState for StringAggState {
511 fn accumulate(&mut self, value: &DataValue) -> Result<()> {
512 match value {
513 DataValue::Null => Ok(()), DataValue::String(s) => {
515 self.values.push(s.clone());
516 Ok(())
517 }
518 DataValue::InternedString(s) => {
519 self.values.push(s.to_string());
520 Ok(())
521 }
522 DataValue::Integer(n) => {
523 self.values.push(n.to_string());
524 Ok(())
525 }
526 DataValue::Float(f) => {
527 self.values.push(f.to_string());
528 Ok(())
529 }
530 DataValue::Boolean(b) => {
531 self.values.push(b.to_string());
532 Ok(())
533 }
534 DataValue::DateTime(dt) => {
535 self.values.push(dt.to_string());
536 Ok(())
537 }
538 }
539 }
540
541 fn finalize(self: Box<Self>) -> DataValue {
542 if self.values.is_empty() {
543 DataValue::Null
544 } else {
545 DataValue::String(self.values.join(&self.separator))
546 }
547 }
548
549 fn clone_box(&self) -> Box<dyn AggregateState> {
550 Box::new(StringAggState {
551 values: self.values.clone(),
552 separator: self.separator.clone(),
553 })
554 }
555
556 fn reset(&mut self) {
557 self.values.clear();
558 }
559}
560
561struct MedianFunction;
564
565impl AggregateFunction for MedianFunction {
566 fn name(&self) -> &str {
567 "MEDIAN"
568 }
569
570 fn description(&self) -> &str {
571 "Calculate the median (middle value) of numeric values"
572 }
573
574 fn create_state(&self) -> Box<dyn AggregateState> {
575 Box::new(CollectorState {
576 values: Vec::new(),
577 function_type: CollectorFunction::Median,
578 })
579 }
580}
581
582struct ModeFunction;
585
586impl AggregateFunction for ModeFunction {
587 fn name(&self) -> &str {
588 "MODE"
589 }
590
591 fn description(&self) -> &str {
592 "Find the most frequently occurring value"
593 }
594
595 fn create_state(&self) -> Box<dyn AggregateState> {
596 Box::new(CollectorState {
597 values: Vec::new(),
598 function_type: CollectorFunction::Mode,
599 })
600 }
601}
602
603struct StdDevFunction;
606
607impl AggregateFunction for StdDevFunction {
608 fn name(&self) -> &str {
609 "STDDEV"
610 }
611
612 fn description(&self) -> &str {
613 "Calculate the sample standard deviation"
614 }
615
616 fn create_state(&self) -> Box<dyn AggregateState> {
617 Box::new(CollectorState {
618 values: Vec::new(),
619 function_type: CollectorFunction::StdDev,
620 })
621 }
622}
623
624struct StdDevPFunction;
627
628impl AggregateFunction for StdDevPFunction {
629 fn name(&self) -> &str {
630 "STDDEV_POP"
631 }
632
633 fn description(&self) -> &str {
634 "Calculate the population standard deviation"
635 }
636
637 fn create_state(&self) -> Box<dyn AggregateState> {
638 Box::new(CollectorState {
639 values: Vec::new(),
640 function_type: CollectorFunction::StdDevP,
641 })
642 }
643}
644
645struct VarianceFunction;
648
649impl AggregateFunction for VarianceFunction {
650 fn name(&self) -> &str {
651 "VARIANCE"
652 }
653
654 fn description(&self) -> &str {
655 "Calculate the sample variance"
656 }
657
658 fn create_state(&self) -> Box<dyn AggregateState> {
659 Box::new(CollectorState {
660 values: Vec::new(),
661 function_type: CollectorFunction::Variance,
662 })
663 }
664}
665
666struct VariancePFunction;
669
670impl AggregateFunction for VariancePFunction {
671 fn name(&self) -> &str {
672 "VARIANCE_POP"
673 }
674
675 fn description(&self) -> &str {
676 "Calculate the population variance"
677 }
678
679 fn create_state(&self) -> Box<dyn AggregateState> {
680 Box::new(CollectorState {
681 values: Vec::new(),
682 function_type: CollectorFunction::VarianceP,
683 })
684 }
685}
686
687struct PercentileFunction;
690
691impl AggregateFunction for PercentileFunction {
692 fn name(&self) -> &str {
693 "PERCENTILE"
694 }
695
696 fn description(&self) -> &str {
697 "Calculate the nth percentile of values"
698 }
699
700 fn create_state(&self) -> Box<dyn AggregateState> {
701 Box::new(PercentileState {
702 values: Vec::new(),
703 percentile: 50.0, })
705 }
706
707 fn set_parameters(&self, params: &[DataValue]) -> Result<Box<dyn AggregateFunction>> {
708 if params.is_empty() {
710 return Ok(Box::new(PercentileFunction));
711 }
712
713 let percentile = match ¶ms[0] {
714 DataValue::Integer(i) => *i as f64,
715 DataValue::Float(f) => *f,
716 _ => {
717 return Err(anyhow!(
718 "PERCENTILE parameter must be a number between 0 and 100"
719 ))
720 }
721 };
722
723 if percentile < 0.0 || percentile > 100.0 {
724 return Err(anyhow!("PERCENTILE must be between 0 and 100"));
725 }
726
727 Ok(Box::new(PercentileWithParam { percentile }))
728 }
729}
730
731struct PercentileWithParam {
732 percentile: f64,
733}
734
735impl AggregateFunction for PercentileWithParam {
736 fn name(&self) -> &str {
737 "PERCENTILE"
738 }
739
740 fn description(&self) -> &str {
741 "Calculate the nth percentile of values"
742 }
743
744 fn create_state(&self) -> Box<dyn AggregateState> {
745 Box::new(PercentileState {
746 values: Vec::new(),
747 percentile: self.percentile,
748 })
749 }
750}
751
752enum CollectorFunction {
755 Median,
756 Mode,
757 StdDev, StdDevP, Variance, VarianceP, }
762
763struct CollectorState {
764 values: Vec<f64>,
765 function_type: CollectorFunction,
766}
767
768impl AggregateState for CollectorState {
769 fn accumulate(&mut self, value: &DataValue) -> Result<()> {
770 match value {
771 DataValue::Null => Ok(()), DataValue::Integer(n) => {
773 self.values.push(*n as f64);
774 Ok(())
775 }
776 DataValue::Float(f) => {
777 self.values.push(*f);
778 Ok(())
779 }
780 _ => match self.function_type {
781 CollectorFunction::Mode => {
782 Err(anyhow!("MODE currently only supports numeric values"))
784 }
785 _ => Err(anyhow!("Statistical functions require numeric values")),
786 },
787 }
788 }
789
790 fn finalize(self: Box<Self>) -> DataValue {
791 if self.values.is_empty() {
792 return DataValue::Null;
793 }
794
795 match self.function_type {
796 CollectorFunction::Median => {
797 let mut sorted = self.values.clone();
798 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
799 let len = sorted.len();
800 if len % 2 == 0 {
801 DataValue::Float((sorted[len / 2 - 1] + sorted[len / 2]) / 2.0)
802 } else {
803 DataValue::Float(sorted[len / 2])
804 }
805 }
806 CollectorFunction::Mode => {
807 use std::collections::HashMap;
808 let mut counts = HashMap::new();
809 for value in &self.values {
810 *counts.entry(value.to_bits()).or_insert(0) += 1;
811 }
812 if let Some((bits, _)) = counts.iter().max_by_key(|&(_, count)| count) {
813 DataValue::Float(f64::from_bits(*bits))
814 } else {
815 DataValue::Null
816 }
817 }
818 CollectorFunction::StdDev | CollectorFunction::Variance => {
819 if self.values.len() < 2 {
821 return DataValue::Null;
822 }
823 let mean = self.values.iter().sum::<f64>() / self.values.len() as f64;
824 let variance = self.values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
825 / (self.values.len() - 1) as f64; match self.function_type {
828 CollectorFunction::StdDev => DataValue::Float(variance.sqrt()),
829 CollectorFunction::Variance => DataValue::Float(variance),
830 _ => unreachable!(),
831 }
832 }
833 CollectorFunction::StdDevP | CollectorFunction::VarianceP => {
834 let mean = self.values.iter().sum::<f64>() / self.values.len() as f64;
836 let variance = self.values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
837 / self.values.len() as f64; match self.function_type {
840 CollectorFunction::StdDevP => DataValue::Float(variance.sqrt()),
841 CollectorFunction::VarianceP => DataValue::Float(variance),
842 _ => unreachable!(),
843 }
844 }
845 }
846 }
847
848 fn clone_box(&self) -> Box<dyn AggregateState> {
849 Box::new(CollectorState {
850 values: self.values.clone(),
851 function_type: match self.function_type {
852 CollectorFunction::Median => CollectorFunction::Median,
853 CollectorFunction::Mode => CollectorFunction::Mode,
854 CollectorFunction::StdDev => CollectorFunction::StdDev,
855 CollectorFunction::StdDevP => CollectorFunction::StdDevP,
856 CollectorFunction::Variance => CollectorFunction::Variance,
857 CollectorFunction::VarianceP => CollectorFunction::VarianceP,
858 },
859 })
860 }
861
862 fn reset(&mut self) {
863 self.values.clear();
864 }
865}
866
867struct PercentileState {
870 values: Vec<f64>,
871 percentile: f64,
872}
873
874impl AggregateState for PercentileState {
875 fn accumulate(&mut self, value: &DataValue) -> Result<()> {
876 match value {
877 DataValue::Null => Ok(()), DataValue::Integer(n) => {
879 self.values.push(*n as f64);
880 Ok(())
881 }
882 DataValue::Float(f) => {
883 self.values.push(*f);
884 Ok(())
885 }
886 _ => Err(anyhow!("PERCENTILE requires numeric values")),
887 }
888 }
889
890 fn finalize(self: Box<Self>) -> DataValue {
891 if self.values.is_empty() {
892 return DataValue::Null;
893 }
894
895 let mut sorted = self.values.clone();
896 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
897
898 let position = (self.percentile / 100.0) * (sorted.len() - 1) as f64;
900 let lower = position.floor() as usize;
901 let upper = position.ceil() as usize;
902
903 if lower == upper {
904 DataValue::Float(sorted[lower])
905 } else {
906 let weight = position - lower as f64;
908 DataValue::Float(sorted[lower] * (1.0 - weight) + sorted[upper] * weight)
909 }
910 }
911
912 fn clone_box(&self) -> Box<dyn AggregateState> {
913 Box::new(PercentileState {
914 values: self.values.clone(),
915 percentile: self.percentile,
916 })
917 }
918
919 fn reset(&mut self) {
920 self.values.clear();
921 }
922}
923
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh registry exposes the core built-in aggregates.
    #[test]
    fn test_registry_creation() {
        let registry = AggregateFunctionRegistry::new();
        let expected = ["COUNT", "SUM", "AVG", "MIN", "MAX", "STRING_AGG"];
        for name in expected.iter() {
            assert!(registry.contains(name));
        }
    }

    /// `COUNT` skips nulls.
    #[test]
    fn test_count_aggregate() {
        let mut state = CountFunction.create_state();

        let inputs = [
            DataValue::Integer(1),
            DataValue::Null,
            DataValue::Integer(3),
        ];
        for input in inputs.iter() {
            state.accumulate(input).unwrap();
        }

        assert_eq!(state.finalize(), DataValue::Integer(2));
    }

    /// `STRING_AGG` joins values with the configured separator.
    #[test]
    fn test_string_agg() {
        let func = StringAggFunction::with_separator(", ".to_string());
        let mut state = func.create_state();

        for fruit in ["apple", "banana", "cherry"].iter() {
            state
                .accumulate(&DataValue::String(fruit.to_string()))
                .unwrap();
        }

        let joined = state.finalize();
        assert_eq!(
            joined,
            DataValue::String("apple, banana, cherry".to_string())
        );
    }
}