sql_cli/sql/aggregate_functions/
mod.rs1use anyhow::{anyhow, Result};
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use crate::data::datatable::DataValue;
10
/// Accumulator for one running aggregate computation.
///
/// Values are fed in one at a time with `accumulate`; `finalize` consumes the
/// state and yields the aggregate result as a `DataValue`.
pub trait AggregateState: Send + Sync {
    /// Feeds one value into the aggregate.
    ///
    /// Implementations in this module return an error when the value's type is
    /// not supported by the aggregate.
    fn accumulate(&mut self, value: &DataValue) -> Result<()>;

    /// Consumes the boxed state and returns the aggregate result.
    fn finalize(self: Box<Self>) -> DataValue;

    /// Returns a boxed copy of this state, usable through the trait object.
    fn clone_box(&self) -> Box<dyn AggregateState>;

    /// Clears the accumulated data so the state can be reused.
    fn reset(&mut self);
}
26
/// Descriptor of an aggregate function that can create fresh accumulator states.
pub trait AggregateFunction: Send + Sync {
    /// Name of the function; the registry upper-cases it to build its lookup key.
    fn name(&self) -> &str;

    /// Human-readable description of what the function computes.
    fn description(&self) -> &str;

    /// Creates a new, empty accumulator for this function.
    fn create_state(&self) -> Box<dyn AggregateState>;

    /// Whether the function supports DISTINCT; defaults to `true`.
    fn supports_distinct(&self) -> bool {
        true
    }

    /// Returns a function configured with `_params`.
    ///
    /// The default implementation ignores the parameters and returns a
    /// `DummyClone` carrying only this function's name.
    ///
    /// NOTE(review): the `create_state` of that placeholder panics, so the
    /// value returned by this default must not be used to build states.
    /// Confirm callers only invoke `set_parameters` on functions that
    /// override it.
    fn set_parameters(&self, _params: &[DataValue]) -> Result<Box<dyn AggregateFunction>> {
        Ok(Box::new(DummyClone(self.name().to_string())))
    }
}
50
51struct DummyClone(String);
53impl AggregateFunction for DummyClone {
54 fn name(&self) -> &str {
55 &self.0
56 }
57 fn description(&self) -> &str {
58 ""
59 }
60 fn create_state(&self) -> Box<dyn AggregateState> {
61 panic!("DummyClone should not be used")
62 }
63}
64
65pub struct AggregateFunctionRegistry {
67 functions: HashMap<String, Arc<Box<dyn AggregateFunction>>>,
68}
69
70impl AggregateFunctionRegistry {
71 pub fn new() -> Self {
72 let mut registry = Self {
73 functions: HashMap::new(),
74 };
75 registry.register_builtin_functions();
76 registry
77 }
78
79 pub fn register(&mut self, function: Box<dyn AggregateFunction>) {
81 let name = function.name().to_uppercase();
82 self.functions.insert(name, Arc::new(function));
83 }
84
85 pub fn get(&self, name: &str) -> Option<Arc<Box<dyn AggregateFunction>>> {
87 self.functions.get(&name.to_uppercase()).cloned()
88 }
89
90 pub fn contains(&self, name: &str) -> bool {
92 self.functions.contains_key(&name.to_uppercase())
93 }
94
95 pub fn list_functions(&self) -> Vec<String> {
97 self.functions.keys().cloned().collect()
98 }
99
100 fn register_builtin_functions(&mut self) {
102 self.register(Box::new(CountFunction));
104 self.register(Box::new(CountStarFunction));
105 self.register(Box::new(SumFunction));
106 self.register(Box::new(AvgFunction));
107 self.register(Box::new(MinFunction));
108 self.register(Box::new(MaxFunction));
109
110 self.register(Box::new(StringAggFunction::new()));
112
113 self.register(Box::new(MedianFunction));
115 self.register(Box::new(ModeFunction));
116 self.register(Box::new(StdDevFunction));
117 self.register(Box::new(StdDevPFunction));
118 self.register(Box::new(VarianceFunction));
119 self.register(Box::new(VariancePFunction));
120 self.register(Box::new(PercentileFunction));
121 }
122}
123
124struct CountFunction;
127
128impl AggregateFunction for CountFunction {
129 fn name(&self) -> &str {
130 "COUNT"
131 }
132
133 fn description(&self) -> &str {
134 "Count the number of non-null values"
135 }
136
137 fn create_state(&self) -> Box<dyn AggregateState> {
138 Box::new(CountState { count: 0 })
139 }
140}
141
142struct CountState {
143 count: i64,
144}
145
146impl AggregateState for CountState {
147 fn accumulate(&mut self, value: &DataValue) -> Result<()> {
148 if !matches!(value, DataValue::Null) {
150 self.count += 1;
151 }
152 Ok(())
153 }
154
155 fn finalize(self: Box<Self>) -> DataValue {
156 DataValue::Integer(self.count)
157 }
158
159 fn clone_box(&self) -> Box<dyn AggregateState> {
160 Box::new(CountState { count: self.count })
161 }
162
163 fn reset(&mut self) {
164 self.count = 0;
165 }
166}
167
168struct CountStarFunction;
170
171impl AggregateFunction for CountStarFunction {
172 fn name(&self) -> &str {
173 "COUNT_STAR"
174 }
175
176 fn description(&self) -> &str {
177 "Count all rows including nulls"
178 }
179
180 fn create_state(&self) -> Box<dyn AggregateState> {
181 Box::new(CountStarState { count: 0 })
182 }
183}
184
185struct CountStarState {
186 count: i64,
187}
188
189impl AggregateState for CountStarState {
190 fn accumulate(&mut self, _value: &DataValue) -> Result<()> {
191 self.count += 1;
193 Ok(())
194 }
195
196 fn finalize(self: Box<Self>) -> DataValue {
197 DataValue::Integer(self.count)
198 }
199
200 fn clone_box(&self) -> Box<dyn AggregateState> {
201 Box::new(CountStarState { count: self.count })
202 }
203
204 fn reset(&mut self) {
205 self.count = 0;
206 }
207}
208
209struct SumFunction;
212
213impl AggregateFunction for SumFunction {
214 fn name(&self) -> &str {
215 "SUM"
216 }
217
218 fn description(&self) -> &str {
219 "Calculate the sum of values"
220 }
221
222 fn create_state(&self) -> Box<dyn AggregateState> {
223 Box::new(SumState {
224 int_sum: None,
225 float_sum: None,
226 has_values: false,
227 })
228 }
229}
230
231struct SumState {
232 int_sum: Option<i64>,
233 float_sum: Option<f64>,
234 has_values: bool,
235}
236
237impl AggregateState for SumState {
238 fn accumulate(&mut self, value: &DataValue) -> Result<()> {
239 match value {
240 DataValue::Null => Ok(()), DataValue::Integer(n) => {
242 self.has_values = true;
243 if let Some(ref mut sum) = self.int_sum {
244 *sum = sum.saturating_add(*n);
245 } else if let Some(ref mut fsum) = self.float_sum {
246 *fsum += *n as f64;
247 } else {
248 self.int_sum = Some(*n);
249 }
250 Ok(())
251 }
252 DataValue::Float(f) => {
253 self.has_values = true;
254 if let Some(isum) = self.int_sum.take() {
256 self.float_sum = Some(isum as f64 + f);
257 } else if let Some(ref mut fsum) = self.float_sum {
258 *fsum += f;
259 } else {
260 self.float_sum = Some(*f);
261 }
262 Ok(())
263 }
264 _ => Err(anyhow!("Cannot sum non-numeric value")),
265 }
266 }
267
268 fn finalize(self: Box<Self>) -> DataValue {
269 if !self.has_values {
270 return DataValue::Null;
271 }
272
273 if let Some(fsum) = self.float_sum {
274 DataValue::Float(fsum)
275 } else if let Some(isum) = self.int_sum {
276 DataValue::Integer(isum)
277 } else {
278 DataValue::Null
279 }
280 }
281
282 fn clone_box(&self) -> Box<dyn AggregateState> {
283 Box::new(SumState {
284 int_sum: self.int_sum,
285 float_sum: self.float_sum,
286 has_values: self.has_values,
287 })
288 }
289
290 fn reset(&mut self) {
291 self.int_sum = None;
292 self.float_sum = None;
293 self.has_values = false;
294 }
295}
296
297struct AvgFunction;
300
301impl AggregateFunction for AvgFunction {
302 fn name(&self) -> &str {
303 "AVG"
304 }
305
306 fn description(&self) -> &str {
307 "Calculate the average of values"
308 }
309
310 fn create_state(&self) -> Box<dyn AggregateState> {
311 Box::new(AvgState {
312 sum: SumState {
313 int_sum: None,
314 float_sum: None,
315 has_values: false,
316 },
317 count: 0,
318 })
319 }
320}
321
322struct AvgState {
323 sum: SumState,
324 count: i64,
325}
326
327impl AggregateState for AvgState {
328 fn accumulate(&mut self, value: &DataValue) -> Result<()> {
329 if !matches!(value, DataValue::Null) {
330 self.sum.accumulate(value)?;
331 self.count += 1;
332 }
333 Ok(())
334 }
335
336 fn finalize(self: Box<Self>) -> DataValue {
337 if self.count == 0 {
338 return DataValue::Null;
339 }
340
341 let sum = Box::new(self.sum).finalize();
342 match sum {
343 DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
344 DataValue::Float(f) => DataValue::Float(f / self.count as f64),
345 _ => DataValue::Null,
346 }
347 }
348
349 fn clone_box(&self) -> Box<dyn AggregateState> {
350 Box::new(AvgState {
351 sum: SumState {
352 int_sum: self.sum.int_sum,
353 float_sum: self.sum.float_sum,
354 has_values: self.sum.has_values,
355 },
356 count: self.count,
357 })
358 }
359
360 fn reset(&mut self) {
361 self.sum.reset();
362 self.count = 0;
363 }
364}
365
366struct MinFunction;
369
370impl AggregateFunction for MinFunction {
371 fn name(&self) -> &str {
372 "MIN"
373 }
374
375 fn description(&self) -> &str {
376 "Find the minimum value"
377 }
378
379 fn create_state(&self) -> Box<dyn AggregateState> {
380 Box::new(MinMaxState {
381 is_min: true,
382 current: None,
383 })
384 }
385}
386
387struct MaxFunction;
390
391impl AggregateFunction for MaxFunction {
392 fn name(&self) -> &str {
393 "MAX"
394 }
395
396 fn description(&self) -> &str {
397 "Find the maximum value"
398 }
399
400 fn create_state(&self) -> Box<dyn AggregateState> {
401 Box::new(MinMaxState {
402 is_min: false,
403 current: None,
404 })
405 }
406}
407
408struct MinMaxState {
409 is_min: bool,
410 current: Option<DataValue>,
411}
412
413impl AggregateState for MinMaxState {
414 fn accumulate(&mut self, value: &DataValue) -> Result<()> {
415 if matches!(value, DataValue::Null) {
416 return Ok(());
417 }
418
419 match &self.current {
420 None => {
421 self.current = Some(value.clone());
422 }
423 Some(current) => {
424 let should_update = if self.is_min {
425 value < current
426 } else {
427 value > current
428 };
429
430 if should_update {
431 self.current = Some(value.clone());
432 }
433 }
434 }
435
436 Ok(())
437 }
438
439 fn finalize(self: Box<Self>) -> DataValue {
440 self.current.unwrap_or(DataValue::Null)
441 }
442
443 fn clone_box(&self) -> Box<dyn AggregateState> {
444 Box::new(MinMaxState {
445 is_min: self.is_min,
446 current: self.current.clone(),
447 })
448 }
449
450 fn reset(&mut self) {
451 self.current = None;
452 }
453}
454
455struct StringAggFunction {
458 separator: String,
459}
460
461impl StringAggFunction {
462 fn new() -> Self {
463 Self {
464 separator: ",".to_string(), }
466 }
467
468 fn with_separator(separator: String) -> Self {
469 Self { separator }
470 }
471}
472
473impl AggregateFunction for StringAggFunction {
474 fn name(&self) -> &str {
475 "STRING_AGG"
476 }
477
478 fn description(&self) -> &str {
479 "Concatenate strings with a separator"
480 }
481
482 fn create_state(&self) -> Box<dyn AggregateState> {
483 Box::new(StringAggState {
484 values: Vec::new(),
485 separator: self.separator.clone(),
486 })
487 }
488
489 fn set_parameters(&self, params: &[DataValue]) -> Result<Box<dyn AggregateFunction>> {
490 if params.is_empty() {
492 return Ok(Box::new(StringAggFunction::new()));
493 }
494
495 let separator = match ¶ms[0] {
496 DataValue::String(s) => s.clone(),
497 DataValue::InternedString(s) => s.to_string(),
498 _ => return Err(anyhow!("STRING_AGG separator must be a string")),
499 };
500
501 Ok(Box::new(StringAggFunction::with_separator(separator)))
502 }
503}
504
505struct StringAggState {
506 values: Vec<String>,
507 separator: String,
508}
509
510impl AggregateState for StringAggState {
511 fn accumulate(&mut self, value: &DataValue) -> Result<()> {
512 match value {
513 DataValue::Null => Ok(()), DataValue::String(s) => {
515 self.values.push(s.clone());
516 Ok(())
517 }
518 DataValue::InternedString(s) => {
519 self.values.push(s.to_string());
520 Ok(())
521 }
522 DataValue::Integer(n) => {
523 self.values.push(n.to_string());
524 Ok(())
525 }
526 DataValue::Float(f) => {
527 self.values.push(f.to_string());
528 Ok(())
529 }
530 DataValue::Boolean(b) => {
531 self.values.push(b.to_string());
532 Ok(())
533 }
534 DataValue::DateTime(dt) => {
535 self.values.push(dt.to_string());
536 Ok(())
537 }
538 DataValue::Vector(v) => {
539 let components: Vec<String> = v.iter().map(|f| f.to_string()).collect();
540 self.values.push(format!("[{}]", components.join(",")));
541 Ok(())
542 }
543 }
544 }
545
546 fn finalize(self: Box<Self>) -> DataValue {
547 if self.values.is_empty() {
548 DataValue::Null
549 } else {
550 DataValue::String(self.values.join(&self.separator))
551 }
552 }
553
554 fn clone_box(&self) -> Box<dyn AggregateState> {
555 Box::new(StringAggState {
556 values: self.values.clone(),
557 separator: self.separator.clone(),
558 })
559 }
560
561 fn reset(&mut self) {
562 self.values.clear();
563 }
564}
565
566struct MedianFunction;
569
570impl AggregateFunction for MedianFunction {
571 fn name(&self) -> &str {
572 "MEDIAN"
573 }
574
575 fn description(&self) -> &str {
576 "Calculate the median (middle value) of numeric values"
577 }
578
579 fn create_state(&self) -> Box<dyn AggregateState> {
580 Box::new(CollectorState {
581 values: Vec::new(),
582 function_type: CollectorFunction::Median,
583 })
584 }
585}
586
587struct ModeFunction;
590
591impl AggregateFunction for ModeFunction {
592 fn name(&self) -> &str {
593 "MODE"
594 }
595
596 fn description(&self) -> &str {
597 "Find the most frequently occurring value"
598 }
599
600 fn create_state(&self) -> Box<dyn AggregateState> {
601 Box::new(CollectorState {
602 values: Vec::new(),
603 function_type: CollectorFunction::Mode,
604 })
605 }
606}
607
608struct StdDevFunction;
611
612impl AggregateFunction for StdDevFunction {
613 fn name(&self) -> &str {
614 "STDDEV"
615 }
616
617 fn description(&self) -> &str {
618 "Calculate the sample standard deviation"
619 }
620
621 fn create_state(&self) -> Box<dyn AggregateState> {
622 Box::new(CollectorState {
623 values: Vec::new(),
624 function_type: CollectorFunction::StdDev,
625 })
626 }
627}
628
629struct StdDevPFunction;
632
633impl AggregateFunction for StdDevPFunction {
634 fn name(&self) -> &str {
635 "STDDEV_POP"
636 }
637
638 fn description(&self) -> &str {
639 "Calculate the population standard deviation"
640 }
641
642 fn create_state(&self) -> Box<dyn AggregateState> {
643 Box::new(CollectorState {
644 values: Vec::new(),
645 function_type: CollectorFunction::StdDevP,
646 })
647 }
648}
649
650struct VarianceFunction;
653
654impl AggregateFunction for VarianceFunction {
655 fn name(&self) -> &str {
656 "VARIANCE"
657 }
658
659 fn description(&self) -> &str {
660 "Calculate the sample variance"
661 }
662
663 fn create_state(&self) -> Box<dyn AggregateState> {
664 Box::new(CollectorState {
665 values: Vec::new(),
666 function_type: CollectorFunction::Variance,
667 })
668 }
669}
670
671struct VariancePFunction;
674
675impl AggregateFunction for VariancePFunction {
676 fn name(&self) -> &str {
677 "VARIANCE_POP"
678 }
679
680 fn description(&self) -> &str {
681 "Calculate the population variance"
682 }
683
684 fn create_state(&self) -> Box<dyn AggregateState> {
685 Box::new(CollectorState {
686 values: Vec::new(),
687 function_type: CollectorFunction::VarianceP,
688 })
689 }
690}
691
692struct PercentileFunction;
695
696impl AggregateFunction for PercentileFunction {
697 fn name(&self) -> &str {
698 "PERCENTILE"
699 }
700
701 fn description(&self) -> &str {
702 "Calculate the nth percentile of values"
703 }
704
705 fn create_state(&self) -> Box<dyn AggregateState> {
706 Box::new(PercentileState {
707 values: Vec::new(),
708 percentile: 50.0, })
710 }
711
712 fn set_parameters(&self, params: &[DataValue]) -> Result<Box<dyn AggregateFunction>> {
713 if params.is_empty() {
715 return Ok(Box::new(PercentileFunction));
716 }
717
718 let percentile = match ¶ms[0] {
719 DataValue::Integer(i) => *i as f64,
720 DataValue::Float(f) => *f,
721 _ => {
722 return Err(anyhow!(
723 "PERCENTILE parameter must be a number between 0 and 100"
724 ))
725 }
726 };
727
728 if percentile < 0.0 || percentile > 100.0 {
729 return Err(anyhow!("PERCENTILE must be between 0 and 100"));
730 }
731
732 Ok(Box::new(PercentileWithParam { percentile }))
733 }
734}
735
736struct PercentileWithParam {
737 percentile: f64,
738}
739
740impl AggregateFunction for PercentileWithParam {
741 fn name(&self) -> &str {
742 "PERCENTILE"
743 }
744
745 fn description(&self) -> &str {
746 "Calculate the nth percentile of values"
747 }
748
749 fn create_state(&self) -> Box<dyn AggregateState> {
750 Box::new(PercentileState {
751 values: Vec::new(),
752 percentile: self.percentile,
753 })
754 }
755}
756
/// Selects which statistic a `CollectorState` computes from its collected values.
///
/// The enum is a plain tag without payload, so it derives `Copy` and the
/// comparison traits; this lets it be duplicated and compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CollectorFunction {
    /// Middle value of the sorted data.
    Median,
    /// Most frequently occurring value.
    Mode,
    /// Sample standard deviation.
    StdDev,
    /// Population standard deviation.
    StdDevP,
    /// Sample variance.
    Variance,
    /// Population variance.
    VarianceP,
}
767
768struct CollectorState {
769 values: Vec<f64>,
770 function_type: CollectorFunction,
771}
772
773impl AggregateState for CollectorState {
774 fn accumulate(&mut self, value: &DataValue) -> Result<()> {
775 match value {
776 DataValue::Null => Ok(()), DataValue::Integer(n) => {
778 self.values.push(*n as f64);
779 Ok(())
780 }
781 DataValue::Float(f) => {
782 self.values.push(*f);
783 Ok(())
784 }
785 _ => match self.function_type {
786 CollectorFunction::Mode => {
787 Err(anyhow!("MODE currently only supports numeric values"))
789 }
790 _ => Err(anyhow!("Statistical functions require numeric values")),
791 },
792 }
793 }
794
795 fn finalize(self: Box<Self>) -> DataValue {
796 if self.values.is_empty() {
797 return DataValue::Null;
798 }
799
800 match self.function_type {
801 CollectorFunction::Median => {
802 let mut sorted = self.values.clone();
803 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
804 let len = sorted.len();
805 if len % 2 == 0 {
806 DataValue::Float((sorted[len / 2 - 1] + sorted[len / 2]) / 2.0)
807 } else {
808 DataValue::Float(sorted[len / 2])
809 }
810 }
811 CollectorFunction::Mode => {
812 use std::collections::HashMap;
813 let mut counts = HashMap::new();
814 for value in &self.values {
815 *counts.entry(value.to_bits()).or_insert(0) += 1;
816 }
817 if let Some((bits, _)) = counts.iter().max_by_key(|&(_, count)| count) {
818 DataValue::Float(f64::from_bits(*bits))
819 } else {
820 DataValue::Null
821 }
822 }
823 CollectorFunction::StdDev | CollectorFunction::Variance => {
824 if self.values.len() < 2 {
826 return DataValue::Null;
827 }
828 let mean = self.values.iter().sum::<f64>() / self.values.len() as f64;
829 let variance = self.values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
830 / (self.values.len() - 1) as f64; match self.function_type {
833 CollectorFunction::StdDev => DataValue::Float(variance.sqrt()),
834 CollectorFunction::Variance => DataValue::Float(variance),
835 _ => unreachable!(),
836 }
837 }
838 CollectorFunction::StdDevP | CollectorFunction::VarianceP => {
839 let mean = self.values.iter().sum::<f64>() / self.values.len() as f64;
841 let variance = self.values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
842 / self.values.len() as f64; match self.function_type {
845 CollectorFunction::StdDevP => DataValue::Float(variance.sqrt()),
846 CollectorFunction::VarianceP => DataValue::Float(variance),
847 _ => unreachable!(),
848 }
849 }
850 }
851 }
852
853 fn clone_box(&self) -> Box<dyn AggregateState> {
854 Box::new(CollectorState {
855 values: self.values.clone(),
856 function_type: match self.function_type {
857 CollectorFunction::Median => CollectorFunction::Median,
858 CollectorFunction::Mode => CollectorFunction::Mode,
859 CollectorFunction::StdDev => CollectorFunction::StdDev,
860 CollectorFunction::StdDevP => CollectorFunction::StdDevP,
861 CollectorFunction::Variance => CollectorFunction::Variance,
862 CollectorFunction::VarianceP => CollectorFunction::VarianceP,
863 },
864 })
865 }
866
867 fn reset(&mut self) {
868 self.values.clear();
869 }
870}
871
872struct PercentileState {
875 values: Vec<f64>,
876 percentile: f64,
877}
878
879impl AggregateState for PercentileState {
880 fn accumulate(&mut self, value: &DataValue) -> Result<()> {
881 match value {
882 DataValue::Null => Ok(()), DataValue::Integer(n) => {
884 self.values.push(*n as f64);
885 Ok(())
886 }
887 DataValue::Float(f) => {
888 self.values.push(*f);
889 Ok(())
890 }
891 _ => Err(anyhow!("PERCENTILE requires numeric values")),
892 }
893 }
894
895 fn finalize(self: Box<Self>) -> DataValue {
896 if self.values.is_empty() {
897 return DataValue::Null;
898 }
899
900 let mut sorted = self.values.clone();
901 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
902
903 let position = (self.percentile / 100.0) * (sorted.len() - 1) as f64;
905 let lower = position.floor() as usize;
906 let upper = position.ceil() as usize;
907
908 if lower == upper {
909 DataValue::Float(sorted[lower])
910 } else {
911 let weight = position - lower as f64;
913 DataValue::Float(sorted[lower] * (1.0 - weight) + sorted[upper] * weight)
914 }
915 }
916
917 fn clone_box(&self) -> Box<dyn AggregateState> {
918 Box::new(PercentileState {
919 values: self.values.clone(),
920 percentile: self.percentile,
921 })
922 }
923
924 fn reset(&mut self) {
925 self.values.clear();
926 }
927}
928
#[cfg(test)]
mod tests {
    use super::*;

    /// The registry created by `new` exposes the basic built-in aggregates.
    #[test]
    fn test_registry_creation() {
        let registry = AggregateFunctionRegistry::new();
        let expected = ["COUNT", "SUM", "AVG", "MIN", "MAX", "STRING_AGG"];
        for name in expected.iter() {
            assert!(registry.contains(name));
        }
    }

    /// COUNT skips nulls and counts everything else.
    #[test]
    fn test_count_aggregate() {
        let mut state = CountFunction.create_state();

        let inputs = [
            DataValue::Integer(1),
            DataValue::Null,
            DataValue::Integer(3),
        ];
        for value in inputs.iter() {
            state.accumulate(value).unwrap();
        }

        assert_eq!(state.finalize(), DataValue::Integer(2));
    }

    /// STRING_AGG joins values in arrival order with the given separator.
    #[test]
    fn test_string_agg() {
        let func = StringAggFunction::with_separator(", ".to_string());
        let mut state = func.create_state();

        for fruit in ["apple", "banana", "cherry"].iter() {
            state
                .accumulate(&DataValue::String(fruit.to_string()))
                .unwrap();
        }

        let expected = DataValue::String("apple, banana, cherry".to_string());
        assert_eq!(state.finalize(), expected);
    }
}