Skip to main content

sql_cli/sql/aggregate_functions/
mod.rs

1// Aggregate Function Registry
2// Provides a clean API for group-based aggregate computations
3// Moves all aggregate logic out of the evaluator into a registry pattern
4
5use anyhow::{anyhow, Result};
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use crate::data::datatable::DataValue;
10
11/// State maintained during aggregation
12/// Each aggregate function manages its own state type
13pub trait AggregateState: Send + Sync {
14    /// Add a value to the aggregate
15    fn accumulate(&mut self, value: &DataValue) -> Result<()>;
16
17    /// Finalize and return the aggregate result
18    fn finalize(self: Box<Self>) -> DataValue;
19
20    /// Create a new instance of this state
21    fn clone_box(&self) -> Box<dyn AggregateState>;
22
23    /// Reset the state for reuse
24    fn reset(&mut self);
25}
26
27/// Aggregate function trait
28/// Each aggregate function (SUM, COUNT, AVG, etc.) implements this
29pub trait AggregateFunction: Send + Sync {
30    /// Function name (e.g., "SUM", "COUNT", "STRING_AGG")
31    fn name(&self) -> &str;
32
33    /// Description for help system
34    fn description(&self) -> &str;
35
36    /// Create initial state for this aggregate
37    fn create_state(&self) -> Box<dyn AggregateState>;
38
39    /// Does this aggregate support DISTINCT?
40    fn supports_distinct(&self) -> bool {
41        true // Most aggregates should support DISTINCT
42    }
43
44    /// For aggregates with parameters (like STRING_AGG separator)
45    fn set_parameters(&self, _params: &[DataValue]) -> Result<Box<dyn AggregateFunction>> {
46        // Default implementation for aggregates without parameters
47        Ok(Box::new(DummyClone(self.name().to_string())))
48    }
49}
50
51// Dummy clone helper for default implementation
52struct DummyClone(String);
53impl AggregateFunction for DummyClone {
54    fn name(&self) -> &str {
55        &self.0
56    }
57    fn description(&self) -> &str {
58        ""
59    }
60    fn create_state(&self) -> Box<dyn AggregateState> {
61        panic!("DummyClone should not be used")
62    }
63}
64
65/// Registry for aggregate functions
66pub struct AggregateFunctionRegistry {
67    functions: HashMap<String, Arc<Box<dyn AggregateFunction>>>,
68}
69
70impl AggregateFunctionRegistry {
71    pub fn new() -> Self {
72        let mut registry = Self {
73            functions: HashMap::new(),
74        };
75        registry.register_builtin_functions();
76        registry
77    }
78
79    /// Register an aggregate function
80    pub fn register(&mut self, function: Box<dyn AggregateFunction>) {
81        let name = function.name().to_uppercase();
82        self.functions.insert(name, Arc::new(function));
83    }
84
85    /// Get an aggregate function by name
86    pub fn get(&self, name: &str) -> Option<Arc<Box<dyn AggregateFunction>>> {
87        self.functions.get(&name.to_uppercase()).cloned()
88    }
89
90    /// Check if a function exists
91    pub fn contains(&self, name: &str) -> bool {
92        self.functions.contains_key(&name.to_uppercase())
93    }
94
95    /// List all registered functions
96    pub fn list_functions(&self) -> Vec<String> {
97        self.functions.keys().cloned().collect()
98    }
99
100    /// Register built-in aggregate functions
101    fn register_builtin_functions(&mut self) {
102        // Basic aggregates
103        self.register(Box::new(CountFunction));
104        self.register(Box::new(CountStarFunction));
105        self.register(Box::new(SumFunction));
106        self.register(Box::new(AvgFunction));
107        self.register(Box::new(MinFunction));
108        self.register(Box::new(MaxFunction));
109
110        // String aggregates
111        self.register(Box::new(StringAggFunction::new()));
112
113        // Statistical aggregates
114        self.register(Box::new(MedianFunction));
115        self.register(Box::new(ModeFunction));
116        self.register(Box::new(StdDevFunction));
117        self.register(Box::new(StdDevPFunction));
118        self.register(Box::new(VarianceFunction));
119        self.register(Box::new(VariancePFunction));
120        self.register(Box::new(PercentileFunction));
121    }
122}
123
124// ============= COUNT Implementation =============
125
126struct CountFunction;
127
128impl AggregateFunction for CountFunction {
129    fn name(&self) -> &str {
130        "COUNT"
131    }
132
133    fn description(&self) -> &str {
134        "Count the number of non-null values"
135    }
136
137    fn create_state(&self) -> Box<dyn AggregateState> {
138        Box::new(CountState { count: 0 })
139    }
140}
141
142struct CountState {
143    count: i64,
144}
145
146impl AggregateState for CountState {
147    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
148        // COUNT(column) counts non-nulls only
149        if !matches!(value, DataValue::Null) {
150            self.count += 1;
151        }
152        Ok(())
153    }
154
155    fn finalize(self: Box<Self>) -> DataValue {
156        DataValue::Integer(self.count)
157    }
158
159    fn clone_box(&self) -> Box<dyn AggregateState> {
160        Box::new(CountState { count: self.count })
161    }
162
163    fn reset(&mut self) {
164        self.count = 0;
165    }
166}
167
168// COUNT(*) - counts all rows including nulls
169struct CountStarFunction;
170
171impl AggregateFunction for CountStarFunction {
172    fn name(&self) -> &str {
173        "COUNT_STAR"
174    }
175
176    fn description(&self) -> &str {
177        "Count all rows including nulls"
178    }
179
180    fn create_state(&self) -> Box<dyn AggregateState> {
181        Box::new(CountStarState { count: 0 })
182    }
183}
184
185struct CountStarState {
186    count: i64,
187}
188
189impl AggregateState for CountStarState {
190    fn accumulate(&mut self, _value: &DataValue) -> Result<()> {
191        // COUNT(*) counts all rows, even nulls
192        self.count += 1;
193        Ok(())
194    }
195
196    fn finalize(self: Box<Self>) -> DataValue {
197        DataValue::Integer(self.count)
198    }
199
200    fn clone_box(&self) -> Box<dyn AggregateState> {
201        Box::new(CountStarState { count: self.count })
202    }
203
204    fn reset(&mut self) {
205        self.count = 0;
206    }
207}
208
209// ============= SUM Implementation =============
210
211struct SumFunction;
212
213impl AggregateFunction for SumFunction {
214    fn name(&self) -> &str {
215        "SUM"
216    }
217
218    fn description(&self) -> &str {
219        "Calculate the sum of values"
220    }
221
222    fn create_state(&self) -> Box<dyn AggregateState> {
223        Box::new(SumState {
224            int_sum: None,
225            float_sum: None,
226            has_values: false,
227        })
228    }
229}
230
231struct SumState {
232    int_sum: Option<i64>,
233    float_sum: Option<f64>,
234    has_values: bool,
235}
236
237impl AggregateState for SumState {
238    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
239        match value {
240            DataValue::Null => Ok(()), // Skip nulls
241            DataValue::Integer(n) => {
242                self.has_values = true;
243                if let Some(ref mut sum) = self.int_sum {
244                    *sum = sum.saturating_add(*n);
245                } else if let Some(ref mut fsum) = self.float_sum {
246                    *fsum += *n as f64;
247                } else {
248                    self.int_sum = Some(*n);
249                }
250                Ok(())
251            }
252            DataValue::Float(f) => {
253                self.has_values = true;
254                // Once we have a float, convert everything to float
255                if let Some(isum) = self.int_sum.take() {
256                    self.float_sum = Some(isum as f64 + f);
257                } else if let Some(ref mut fsum) = self.float_sum {
258                    *fsum += f;
259                } else {
260                    self.float_sum = Some(*f);
261                }
262                Ok(())
263            }
264            DataValue::Boolean(b) => {
265                // Coerce boolean to integer: true=1, false=0
266                // Enables patterns like AVG(x > 5), SUM(col = 'value')
267                let n = if *b { 1i64 } else { 0i64 };
268                self.has_values = true;
269                if let Some(ref mut sum) = self.int_sum {
270                    *sum = sum.saturating_add(n);
271                } else if let Some(ref mut fsum) = self.float_sum {
272                    *fsum += n as f64;
273                } else {
274                    self.int_sum = Some(n);
275                }
276                Ok(())
277            }
278            _ => Err(anyhow!("Cannot sum non-numeric value")),
279        }
280    }
281
282    fn finalize(self: Box<Self>) -> DataValue {
283        if !self.has_values {
284            return DataValue::Null;
285        }
286
287        if let Some(fsum) = self.float_sum {
288            DataValue::Float(fsum)
289        } else if let Some(isum) = self.int_sum {
290            DataValue::Integer(isum)
291        } else {
292            DataValue::Null
293        }
294    }
295
296    fn clone_box(&self) -> Box<dyn AggregateState> {
297        Box::new(SumState {
298            int_sum: self.int_sum,
299            float_sum: self.float_sum,
300            has_values: self.has_values,
301        })
302    }
303
304    fn reset(&mut self) {
305        self.int_sum = None;
306        self.float_sum = None;
307        self.has_values = false;
308    }
309}
310
311// ============= AVG Implementation =============
312
313struct AvgFunction;
314
315impl AggregateFunction for AvgFunction {
316    fn name(&self) -> &str {
317        "AVG"
318    }
319
320    fn description(&self) -> &str {
321        "Calculate the average of values"
322    }
323
324    fn create_state(&self) -> Box<dyn AggregateState> {
325        Box::new(AvgState {
326            sum: SumState {
327                int_sum: None,
328                float_sum: None,
329                has_values: false,
330            },
331            count: 0,
332        })
333    }
334}
335
336struct AvgState {
337    sum: SumState,
338    count: i64,
339}
340
341impl AggregateState for AvgState {
342    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
343        if !matches!(value, DataValue::Null) {
344            self.sum.accumulate(value)?;
345            self.count += 1;
346        }
347        Ok(())
348    }
349
350    fn finalize(self: Box<Self>) -> DataValue {
351        if self.count == 0 {
352            return DataValue::Null;
353        }
354
355        let sum = Box::new(self.sum).finalize();
356        match sum {
357            DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
358            DataValue::Float(f) => DataValue::Float(f / self.count as f64),
359            _ => DataValue::Null,
360        }
361    }
362
363    fn clone_box(&self) -> Box<dyn AggregateState> {
364        Box::new(AvgState {
365            sum: SumState {
366                int_sum: self.sum.int_sum,
367                float_sum: self.sum.float_sum,
368                has_values: self.sum.has_values,
369            },
370            count: self.count,
371        })
372    }
373
374    fn reset(&mut self) {
375        self.sum.reset();
376        self.count = 0;
377    }
378}
379
380// ============= MIN Implementation =============
381
382struct MinFunction;
383
384impl AggregateFunction for MinFunction {
385    fn name(&self) -> &str {
386        "MIN"
387    }
388
389    fn description(&self) -> &str {
390        "Find the minimum value"
391    }
392
393    fn create_state(&self) -> Box<dyn AggregateState> {
394        Box::new(MinMaxState {
395            is_min: true,
396            current: None,
397        })
398    }
399}
400
401// ============= MAX Implementation =============
402
403struct MaxFunction;
404
405impl AggregateFunction for MaxFunction {
406    fn name(&self) -> &str {
407        "MAX"
408    }
409
410    fn description(&self) -> &str {
411        "Find the maximum value"
412    }
413
414    fn create_state(&self) -> Box<dyn AggregateState> {
415        Box::new(MinMaxState {
416            is_min: false,
417            current: None,
418        })
419    }
420}
421
422struct MinMaxState {
423    is_min: bool,
424    current: Option<DataValue>,
425}
426
427impl AggregateState for MinMaxState {
428    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
429        if matches!(value, DataValue::Null) {
430            return Ok(());
431        }
432
433        match &self.current {
434            None => {
435                self.current = Some(value.clone());
436            }
437            Some(current) => {
438                let should_update = if self.is_min {
439                    value < current
440                } else {
441                    value > current
442                };
443
444                if should_update {
445                    self.current = Some(value.clone());
446                }
447            }
448        }
449
450        Ok(())
451    }
452
453    fn finalize(self: Box<Self>) -> DataValue {
454        self.current.unwrap_or(DataValue::Null)
455    }
456
457    fn clone_box(&self) -> Box<dyn AggregateState> {
458        Box::new(MinMaxState {
459            is_min: self.is_min,
460            current: self.current.clone(),
461        })
462    }
463
464    fn reset(&mut self) {
465        self.current = None;
466    }
467}
468
469// ============= STRING_AGG Implementation =============
470
/// STRING_AGG(expr, separator): joins the text form of non-null inputs.
struct StringAggFunction {
    /// Text placed between consecutive values.
    separator: String,
}

impl StringAggFunction {
    /// STRING_AGG using the default `,` separator.
    fn new() -> Self {
        Self::with_separator(",".to_string())
    }

    /// STRING_AGG using a caller-supplied separator.
    fn with_separator(separator: String) -> Self {
        StringAggFunction { separator }
    }
}
486
487impl AggregateFunction for StringAggFunction {
488    fn name(&self) -> &str {
489        "STRING_AGG"
490    }
491
492    fn description(&self) -> &str {
493        "Concatenate strings with a separator"
494    }
495
496    fn create_state(&self) -> Box<dyn AggregateState> {
497        Box::new(StringAggState {
498            values: Vec::new(),
499            separator: self.separator.clone(),
500        })
501    }
502
503    fn set_parameters(&self, params: &[DataValue]) -> Result<Box<dyn AggregateFunction>> {
504        // STRING_AGG takes a separator as second parameter
505        if params.is_empty() {
506            return Ok(Box::new(StringAggFunction::new()));
507        }
508
509        let separator = match &params[0] {
510            DataValue::String(s) => s.clone(),
511            DataValue::InternedString(s) => s.to_string(),
512            _ => return Err(anyhow!("STRING_AGG separator must be a string")),
513        };
514
515        Ok(Box::new(StringAggFunction::with_separator(separator)))
516    }
517}
518
519struct StringAggState {
520    values: Vec<String>,
521    separator: String,
522}
523
524impl AggregateState for StringAggState {
525    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
526        match value {
527            DataValue::Null => Ok(()), // Skip nulls
528            DataValue::String(s) => {
529                self.values.push(s.clone());
530                Ok(())
531            }
532            DataValue::InternedString(s) => {
533                self.values.push(s.to_string());
534                Ok(())
535            }
536            DataValue::Integer(n) => {
537                self.values.push(n.to_string());
538                Ok(())
539            }
540            DataValue::Float(f) => {
541                self.values.push(f.to_string());
542                Ok(())
543            }
544            DataValue::Boolean(b) => {
545                self.values.push(b.to_string());
546                Ok(())
547            }
548            DataValue::DateTime(dt) => {
549                self.values.push(dt.to_string());
550                Ok(())
551            }
552            DataValue::Vector(v) => {
553                let components: Vec<String> = v.iter().map(|f| f.to_string()).collect();
554                self.values.push(format!("[{}]", components.join(",")));
555                Ok(())
556            }
557        }
558    }
559
560    fn finalize(self: Box<Self>) -> DataValue {
561        if self.values.is_empty() {
562            DataValue::Null
563        } else {
564            DataValue::String(self.values.join(&self.separator))
565        }
566    }
567
568    fn clone_box(&self) -> Box<dyn AggregateState> {
569        Box::new(StringAggState {
570            values: self.values.clone(),
571            separator: self.separator.clone(),
572        })
573    }
574
575    fn reset(&mut self) {
576        self.values.clear();
577    }
578}
579
580// ============= MEDIAN Implementation =============
581
582struct MedianFunction;
583
584impl AggregateFunction for MedianFunction {
585    fn name(&self) -> &str {
586        "MEDIAN"
587    }
588
589    fn description(&self) -> &str {
590        "Calculate the median (middle value) of numeric values"
591    }
592
593    fn create_state(&self) -> Box<dyn AggregateState> {
594        Box::new(CollectorState {
595            values: Vec::new(),
596            function_type: CollectorFunction::Median,
597        })
598    }
599}
600
601// ============= MODE Implementation =============
602
603struct ModeFunction;
604
605impl AggregateFunction for ModeFunction {
606    fn name(&self) -> &str {
607        "MODE"
608    }
609
610    fn description(&self) -> &str {
611        "Find the most frequently occurring value"
612    }
613
614    fn create_state(&self) -> Box<dyn AggregateState> {
615        Box::new(CollectorState {
616            values: Vec::new(),
617            function_type: CollectorFunction::Mode,
618        })
619    }
620}
621
622// ============= STDDEV Implementation =============
623
624struct StdDevFunction;
625
626impl AggregateFunction for StdDevFunction {
627    fn name(&self) -> &str {
628        "STDDEV"
629    }
630
631    fn description(&self) -> &str {
632        "Calculate the sample standard deviation"
633    }
634
635    fn create_state(&self) -> Box<dyn AggregateState> {
636        Box::new(CollectorState {
637            values: Vec::new(),
638            function_type: CollectorFunction::StdDev,
639        })
640    }
641}
642
643// ============= STDDEV_POP Implementation =============
644
645struct StdDevPFunction;
646
647impl AggregateFunction for StdDevPFunction {
648    fn name(&self) -> &str {
649        "STDDEV_POP"
650    }
651
652    fn description(&self) -> &str {
653        "Calculate the population standard deviation"
654    }
655
656    fn create_state(&self) -> Box<dyn AggregateState> {
657        Box::new(CollectorState {
658            values: Vec::new(),
659            function_type: CollectorFunction::StdDevP,
660        })
661    }
662}
663
664// ============= VARIANCE Implementation =============
665
666struct VarianceFunction;
667
668impl AggregateFunction for VarianceFunction {
669    fn name(&self) -> &str {
670        "VARIANCE"
671    }
672
673    fn description(&self) -> &str {
674        "Calculate the sample variance"
675    }
676
677    fn create_state(&self) -> Box<dyn AggregateState> {
678        Box::new(CollectorState {
679            values: Vec::new(),
680            function_type: CollectorFunction::Variance,
681        })
682    }
683}
684
685// ============= VARIANCE_POP Implementation =============
686
687struct VariancePFunction;
688
689impl AggregateFunction for VariancePFunction {
690    fn name(&self) -> &str {
691        "VARIANCE_POP"
692    }
693
694    fn description(&self) -> &str {
695        "Calculate the population variance"
696    }
697
698    fn create_state(&self) -> Box<dyn AggregateState> {
699        Box::new(CollectorState {
700            values: Vec::new(),
701            function_type: CollectorFunction::VarianceP,
702        })
703    }
704}
705
706// ============= PERCENTILE Implementation =============
707
708struct PercentileFunction;
709
710impl AggregateFunction for PercentileFunction {
711    fn name(&self) -> &str {
712        "PERCENTILE"
713    }
714
715    fn description(&self) -> &str {
716        "Calculate the nth percentile of values"
717    }
718
719    fn create_state(&self) -> Box<dyn AggregateState> {
720        Box::new(PercentileState {
721            values: Vec::new(),
722            percentile: 50.0, // Default to median
723        })
724    }
725
726    fn set_parameters(&self, params: &[DataValue]) -> Result<Box<dyn AggregateFunction>> {
727        // PERCENTILE takes the percentile value as a parameter
728        if params.is_empty() {
729            return Ok(Box::new(PercentileFunction));
730        }
731
732        let percentile = match &params[0] {
733            DataValue::Integer(i) => *i as f64,
734            DataValue::Float(f) => *f,
735            _ => {
736                return Err(anyhow!(
737                    "PERCENTILE parameter must be a number between 0 and 100"
738                ))
739            }
740        };
741
742        if percentile < 0.0 || percentile > 100.0 {
743            return Err(anyhow!("PERCENTILE must be between 0 and 100"));
744        }
745
746        Ok(Box::new(PercentileWithParam { percentile }))
747    }
748}
749
750struct PercentileWithParam {
751    percentile: f64,
752}
753
754impl AggregateFunction for PercentileWithParam {
755    fn name(&self) -> &str {
756        "PERCENTILE"
757    }
758
759    fn description(&self) -> &str {
760        "Calculate the nth percentile of values"
761    }
762
763    fn create_state(&self) -> Box<dyn AggregateState> {
764        Box::new(PercentileState {
765            values: Vec::new(),
766            percentile: self.percentile,
767        })
768    }
769}
770
771// ============= Collector State for functions that need all values =============
772
773enum CollectorFunction {
774    Median,
775    Mode,
776    StdDev,    // Sample standard deviation
777    StdDevP,   // Population standard deviation
778    Variance,  // Sample variance
779    VarianceP, // Population variance
780}
781
782struct CollectorState {
783    values: Vec<f64>,
784    function_type: CollectorFunction,
785}
786
787impl AggregateState for CollectorState {
788    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
789        match value {
790            DataValue::Null => Ok(()), // Skip nulls
791            DataValue::Integer(n) => {
792                self.values.push(*n as f64);
793                Ok(())
794            }
795            DataValue::Float(f) => {
796                self.values.push(*f);
797                Ok(())
798            }
799            _ => match self.function_type {
800                CollectorFunction::Mode => {
801                    // Mode can work with non-numeric types, but we'll handle that separately
802                    Err(anyhow!("MODE currently only supports numeric values"))
803                }
804                _ => Err(anyhow!("Statistical functions require numeric values")),
805            },
806        }
807    }
808
809    fn finalize(self: Box<Self>) -> DataValue {
810        if self.values.is_empty() {
811            return DataValue::Null;
812        }
813
814        match self.function_type {
815            CollectorFunction::Median => {
816                let mut sorted = self.values.clone();
817                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
818                let len = sorted.len();
819                if len % 2 == 0 {
820                    DataValue::Float((sorted[len / 2 - 1] + sorted[len / 2]) / 2.0)
821                } else {
822                    DataValue::Float(sorted[len / 2])
823                }
824            }
825            CollectorFunction::Mode => {
826                use std::collections::HashMap;
827                let mut counts = HashMap::new();
828                for value in &self.values {
829                    *counts.entry(value.to_bits()).or_insert(0) += 1;
830                }
831                if let Some((bits, _)) = counts.iter().max_by_key(|&(_, count)| count) {
832                    DataValue::Float(f64::from_bits(*bits))
833                } else {
834                    DataValue::Null
835                }
836            }
837            CollectorFunction::StdDev | CollectorFunction::Variance => {
838                // Sample standard deviation and variance
839                if self.values.len() < 2 {
840                    return DataValue::Null;
841                }
842                let mean = self.values.iter().sum::<f64>() / self.values.len() as f64;
843                let variance = self.values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
844                    / (self.values.len() - 1) as f64; // N-1 for sample
845
846                match self.function_type {
847                    CollectorFunction::StdDev => DataValue::Float(variance.sqrt()),
848                    CollectorFunction::Variance => DataValue::Float(variance),
849                    _ => unreachable!(),
850                }
851            }
852            CollectorFunction::StdDevP | CollectorFunction::VarianceP => {
853                // Population standard deviation and variance
854                let mean = self.values.iter().sum::<f64>() / self.values.len() as f64;
855                let variance = self.values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
856                    / self.values.len() as f64; // N for population
857
858                match self.function_type {
859                    CollectorFunction::StdDevP => DataValue::Float(variance.sqrt()),
860                    CollectorFunction::VarianceP => DataValue::Float(variance),
861                    _ => unreachable!(),
862                }
863            }
864        }
865    }
866
867    fn clone_box(&self) -> Box<dyn AggregateState> {
868        Box::new(CollectorState {
869            values: self.values.clone(),
870            function_type: match self.function_type {
871                CollectorFunction::Median => CollectorFunction::Median,
872                CollectorFunction::Mode => CollectorFunction::Mode,
873                CollectorFunction::StdDev => CollectorFunction::StdDev,
874                CollectorFunction::StdDevP => CollectorFunction::StdDevP,
875                CollectorFunction::Variance => CollectorFunction::Variance,
876                CollectorFunction::VarianceP => CollectorFunction::VarianceP,
877            },
878        })
879    }
880
881    fn reset(&mut self) {
882        self.values.clear();
883    }
884}
885
886// ============= Percentile State =============
887
888struct PercentileState {
889    values: Vec<f64>,
890    percentile: f64,
891}
892
893impl AggregateState for PercentileState {
894    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
895        match value {
896            DataValue::Null => Ok(()), // Skip nulls
897            DataValue::Integer(n) => {
898                self.values.push(*n as f64);
899                Ok(())
900            }
901            DataValue::Float(f) => {
902                self.values.push(*f);
903                Ok(())
904            }
905            _ => Err(anyhow!("PERCENTILE requires numeric values")),
906        }
907    }
908
909    fn finalize(self: Box<Self>) -> DataValue {
910        if self.values.is_empty() {
911            return DataValue::Null;
912        }
913
914        let mut sorted = self.values.clone();
915        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
916
917        // Calculate the position in the sorted array
918        let position = (self.percentile / 100.0) * (sorted.len() - 1) as f64;
919        let lower = position.floor() as usize;
920        let upper = position.ceil() as usize;
921
922        if lower == upper {
923            DataValue::Float(sorted[lower])
924        } else {
925            // Linear interpolation between two values
926            let weight = position - lower as f64;
927            DataValue::Float(sorted[lower] * (1.0 - weight) + sorted[upper] * weight)
928        }
929    }
930
931    fn clone_box(&self) -> Box<dyn AggregateState> {
932        Box::new(PercentileState {
933            values: self.values.clone(),
934            percentile: self.percentile,
935        })
936    }
937
938    fn reset(&mut self) {
939        self.values.clear();
940    }
941}
942
#[cfg(test)]
mod tests {
    use super::*;

    /// Feed `inputs` through a fresh state of `func` and return the finalized result.
    fn run(func: &dyn AggregateFunction, inputs: &[DataValue]) -> DataValue {
        let mut state = func.create_state();
        for value in inputs {
            state.accumulate(value).unwrap();
        }
        state.finalize()
    }

    /// Wrap integers as `DataValue::Integer`.
    fn ints(values: &[i64]) -> Vec<DataValue> {
        values.iter().map(|&n| DataValue::Integer(n)).collect()
    }

    #[test]
    fn test_registry_creation() {
        let registry = AggregateFunctionRegistry::new();
        let expected = [
            "COUNT",
            "COUNT_STAR",
            "SUM",
            "AVG",
            "MIN",
            "MAX",
            "STRING_AGG",
            "MEDIAN",
            "MODE",
            "STDDEV",
            "STDDEV_POP",
            "VARIANCE",
            "VARIANCE_POP",
            "PERCENTILE",
        ];
        for name in expected.iter() {
            assert!(registry.contains(name), "missing {}", name);
        }
    }

    #[test]
    fn test_registry_lookup_is_case_insensitive() {
        let registry = AggregateFunctionRegistry::new();
        assert!(registry.contains("sum"));
        let median = registry.get("Median").expect("MEDIAN is registered");
        assert_eq!(median.name(), "MEDIAN");
        assert!(registry.get("NO_SUCH_AGGREGATE").is_none());
    }

    #[test]
    fn test_count_aggregate() {
        let func = CountFunction;
        let mut state = func.create_state();

        state.accumulate(&DataValue::Integer(1)).unwrap();
        state.accumulate(&DataValue::Null).unwrap();
        state.accumulate(&DataValue::Integer(3)).unwrap();

        let result = state.finalize();
        assert_eq!(result, DataValue::Integer(2));
    }

    #[test]
    fn test_count_star_counts_nulls() {
        let inputs = [DataValue::Null, DataValue::Integer(1), DataValue::Null];
        assert_eq!(run(&CountStarFunction, &inputs), DataValue::Integer(3));
    }

    #[test]
    fn test_sum_aggregate() {
        let mixed = [
            DataValue::Integer(1),
            DataValue::Null,
            DataValue::Integer(2),
            DataValue::Boolean(true),
        ];
        assert_eq!(run(&SumFunction, &mixed), DataValue::Integer(4));

        let with_float = [DataValue::Integer(1), DataValue::Float(0.5)];
        assert_eq!(run(&SumFunction, &with_float), DataValue::Float(1.5));

        assert_eq!(run(&SumFunction, &[DataValue::Null]), DataValue::Null);

        let mut state = SumFunction.create_state();
        assert!(state
            .accumulate(&DataValue::String("x".to_string()))
            .is_err());
    }

    #[test]
    fn test_avg_aggregate() {
        assert_eq!(run(&AvgFunction, &ints(&[1, 2, 3, 4])), DataValue::Float(2.5));
        assert_eq!(run(&AvgFunction, &[DataValue::Null]), DataValue::Null);
    }

    #[test]
    fn test_min_max_skip_nulls() {
        let inputs = [
            DataValue::Integer(3),
            DataValue::Null,
            DataValue::Integer(1),
            DataValue::Integer(2),
        ];
        assert_eq!(run(&MinFunction, &inputs), DataValue::Integer(1));
        assert_eq!(run(&MaxFunction, &inputs), DataValue::Integer(3));
        assert_eq!(run(&MinFunction, &[DataValue::Null]), DataValue::Null);
    }

    #[test]
    fn test_median() {
        assert_eq!(run(&MedianFunction, &ints(&[3, 1, 2])), DataValue::Float(2.0));
        assert_eq!(run(&MedianFunction, &ints(&[4, 1, 3, 2])), DataValue::Float(2.5));
    }

    #[test]
    fn test_variance_and_stddev() {
        let inputs = ints(&[2, 4, 4, 4, 5, 5, 7, 9]);
        assert_eq!(run(&VariancePFunction, &inputs), DataValue::Float(4.0));
        assert_eq!(run(&StdDevPFunction, &inputs), DataValue::Float(2.0));
        assert_eq!(run(&VarianceFunction, &inputs), DataValue::Float(32.0 / 7.0));
        // Sample statistics need at least two values.
        assert_eq!(run(&StdDevFunction, &ints(&[5])), DataValue::Null);
    }

    #[test]
    fn test_percentile() {
        let p25 = PercentileFunction
            .set_parameters(&[DataValue::Integer(25)])
            .unwrap();
        assert_eq!(run(&*p25, &ints(&[5, 1, 4, 2, 3])), DataValue::Float(2.0));

        // Without a parameter PERCENTILE defaults to the 50th percentile.
        assert_eq!(
            run(&PercentileFunction, &ints(&[1, 2, 3, 4])),
            DataValue::Float(2.5)
        );

        assert!(PercentileFunction
            .set_parameters(&[DataValue::Integer(101)])
            .is_err());
    }

    #[test]
    fn test_string_agg() {
        let func = StringAggFunction::with_separator(", ".to_string());
        let mut state = func.create_state();

        state
            .accumulate(&DataValue::String("apple".to_string()))
            .unwrap();
        state
            .accumulate(&DataValue::String("banana".to_string()))
            .unwrap();
        state
            .accumulate(&DataValue::String("cherry".to_string()))
            .unwrap();

        let result = state.finalize();
        assert_eq!(
            result,
            DataValue::String("apple, banana, cherry".to_string())
        );
    }

    #[test]
    fn test_string_agg_separator_parameter() {
        let func = StringAggFunction::new()
            .set_parameters(&[DataValue::String(" | ".to_string())])
            .unwrap();
        let inputs = [
            DataValue::String("a".to_string()),
            DataValue::Null,
            DataValue::Integer(2),
        ];
        assert_eq!(run(&*func, &inputs), DataValue::String("a | 2".to_string()));
    }
}