// sql_cli/sql/aggregate_functions/mod.rs

// Aggregate Function Registry
// Provides a clean API for group-based aggregate computations.
// Moves all aggregate logic out of the evaluator into a registry pattern.
4
5use anyhow::{anyhow, Result};
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use crate::data::datatable::DataValue;
10
11/// State maintained during aggregation
12/// Each aggregate function manages its own state type
13pub trait AggregateState: Send + Sync {
14    /// Add a value to the aggregate
15    fn accumulate(&mut self, value: &DataValue) -> Result<()>;
16
17    /// Finalize and return the aggregate result
18    fn finalize(self: Box<Self>) -> DataValue;
19
20    /// Create a new instance of this state
21    fn clone_box(&self) -> Box<dyn AggregateState>;
22
23    /// Reset the state for reuse
24    fn reset(&mut self);
25}
26
27/// Aggregate function trait
28/// Each aggregate function (SUM, COUNT, AVG, etc.) implements this
29pub trait AggregateFunction: Send + Sync {
30    /// Function name (e.g., "SUM", "COUNT", "STRING_AGG")
31    fn name(&self) -> &str;
32
33    /// Description for help system
34    fn description(&self) -> &str;
35
36    /// Create initial state for this aggregate
37    fn create_state(&self) -> Box<dyn AggregateState>;
38
39    /// Does this aggregate support DISTINCT?
40    fn supports_distinct(&self) -> bool {
41        true // Most aggregates should support DISTINCT
42    }
43
44    /// For aggregates with parameters (like STRING_AGG separator)
45    fn set_parameters(&self, _params: &[DataValue]) -> Result<Box<dyn AggregateFunction>> {
46        // Default implementation for aggregates without parameters
47        Ok(Box::new(DummyClone(self.name().to_string())))
48    }
49}
50
51// Dummy clone helper for default implementation
52struct DummyClone(String);
53impl AggregateFunction for DummyClone {
54    fn name(&self) -> &str {
55        &self.0
56    }
57    fn description(&self) -> &str {
58        ""
59    }
60    fn create_state(&self) -> Box<dyn AggregateState> {
61        panic!("DummyClone should not be used")
62    }
63}
64
65/// Registry for aggregate functions
66pub struct AggregateFunctionRegistry {
67    functions: HashMap<String, Arc<Box<dyn AggregateFunction>>>,
68}
69
70impl AggregateFunctionRegistry {
71    pub fn new() -> Self {
72        let mut registry = Self {
73            functions: HashMap::new(),
74        };
75        registry.register_builtin_functions();
76        registry
77    }
78
79    /// Register an aggregate function
80    pub fn register(&mut self, function: Box<dyn AggregateFunction>) {
81        let name = function.name().to_uppercase();
82        self.functions.insert(name, Arc::new(function));
83    }
84
85    /// Get an aggregate function by name
86    pub fn get(&self, name: &str) -> Option<Arc<Box<dyn AggregateFunction>>> {
87        self.functions.get(&name.to_uppercase()).cloned()
88    }
89
90    /// Check if a function exists
91    pub fn contains(&self, name: &str) -> bool {
92        self.functions.contains_key(&name.to_uppercase())
93    }
94
95    /// List all registered functions
96    pub fn list_functions(&self) -> Vec<String> {
97        self.functions.keys().cloned().collect()
98    }
99
100    /// Register built-in aggregate functions
101    fn register_builtin_functions(&mut self) {
102        // Basic aggregates
103        self.register(Box::new(CountFunction));
104        self.register(Box::new(CountStarFunction));
105        self.register(Box::new(SumFunction));
106        self.register(Box::new(AvgFunction));
107        self.register(Box::new(MinFunction));
108        self.register(Box::new(MaxFunction));
109
110        // String aggregates
111        self.register(Box::new(StringAggFunction::new()));
112
113        // Statistical aggregates
114        self.register(Box::new(MedianFunction));
115        self.register(Box::new(ModeFunction));
116        self.register(Box::new(StdDevFunction));
117        self.register(Box::new(StdDevPFunction));
118        self.register(Box::new(VarianceFunction));
119        self.register(Box::new(VariancePFunction));
120        self.register(Box::new(PercentileFunction));
121    }
122}
123
124// ============= COUNT Implementation =============
125
126struct CountFunction;
127
128impl AggregateFunction for CountFunction {
129    fn name(&self) -> &str {
130        "COUNT"
131    }
132
133    fn description(&self) -> &str {
134        "Count the number of non-null values"
135    }
136
137    fn create_state(&self) -> Box<dyn AggregateState> {
138        Box::new(CountState { count: 0 })
139    }
140}
141
142struct CountState {
143    count: i64,
144}
145
146impl AggregateState for CountState {
147    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
148        // COUNT(column) counts non-nulls only
149        if !matches!(value, DataValue::Null) {
150            self.count += 1;
151        }
152        Ok(())
153    }
154
155    fn finalize(self: Box<Self>) -> DataValue {
156        DataValue::Integer(self.count)
157    }
158
159    fn clone_box(&self) -> Box<dyn AggregateState> {
160        Box::new(CountState { count: self.count })
161    }
162
163    fn reset(&mut self) {
164        self.count = 0;
165    }
166}
167
168// COUNT(*) - counts all rows including nulls
169struct CountStarFunction;
170
171impl AggregateFunction for CountStarFunction {
172    fn name(&self) -> &str {
173        "COUNT_STAR"
174    }
175
176    fn description(&self) -> &str {
177        "Count all rows including nulls"
178    }
179
180    fn create_state(&self) -> Box<dyn AggregateState> {
181        Box::new(CountStarState { count: 0 })
182    }
183}
184
185struct CountStarState {
186    count: i64,
187}
188
189impl AggregateState for CountStarState {
190    fn accumulate(&mut self, _value: &DataValue) -> Result<()> {
191        // COUNT(*) counts all rows, even nulls
192        self.count += 1;
193        Ok(())
194    }
195
196    fn finalize(self: Box<Self>) -> DataValue {
197        DataValue::Integer(self.count)
198    }
199
200    fn clone_box(&self) -> Box<dyn AggregateState> {
201        Box::new(CountStarState { count: self.count })
202    }
203
204    fn reset(&mut self) {
205        self.count = 0;
206    }
207}
208
209// ============= SUM Implementation =============
210
211struct SumFunction;
212
213impl AggregateFunction for SumFunction {
214    fn name(&self) -> &str {
215        "SUM"
216    }
217
218    fn description(&self) -> &str {
219        "Calculate the sum of values"
220    }
221
222    fn create_state(&self) -> Box<dyn AggregateState> {
223        Box::new(SumState {
224            int_sum: None,
225            float_sum: None,
226            has_values: false,
227        })
228    }
229}
230
231struct SumState {
232    int_sum: Option<i64>,
233    float_sum: Option<f64>,
234    has_values: bool,
235}
236
237impl AggregateState for SumState {
238    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
239        match value {
240            DataValue::Null => Ok(()), // Skip nulls
241            DataValue::Integer(n) => {
242                self.has_values = true;
243                if let Some(ref mut sum) = self.int_sum {
244                    *sum = sum.saturating_add(*n);
245                } else if let Some(ref mut fsum) = self.float_sum {
246                    *fsum += *n as f64;
247                } else {
248                    self.int_sum = Some(*n);
249                }
250                Ok(())
251            }
252            DataValue::Float(f) => {
253                self.has_values = true;
254                // Once we have a float, convert everything to float
255                if let Some(isum) = self.int_sum.take() {
256                    self.float_sum = Some(isum as f64 + f);
257                } else if let Some(ref mut fsum) = self.float_sum {
258                    *fsum += f;
259                } else {
260                    self.float_sum = Some(*f);
261                }
262                Ok(())
263            }
264            _ => Err(anyhow!("Cannot sum non-numeric value")),
265        }
266    }
267
268    fn finalize(self: Box<Self>) -> DataValue {
269        if !self.has_values {
270            return DataValue::Null;
271        }
272
273        if let Some(fsum) = self.float_sum {
274            DataValue::Float(fsum)
275        } else if let Some(isum) = self.int_sum {
276            DataValue::Integer(isum)
277        } else {
278            DataValue::Null
279        }
280    }
281
282    fn clone_box(&self) -> Box<dyn AggregateState> {
283        Box::new(SumState {
284            int_sum: self.int_sum,
285            float_sum: self.float_sum,
286            has_values: self.has_values,
287        })
288    }
289
290    fn reset(&mut self) {
291        self.int_sum = None;
292        self.float_sum = None;
293        self.has_values = false;
294    }
295}
296
297// ============= AVG Implementation =============
298
299struct AvgFunction;
300
301impl AggregateFunction for AvgFunction {
302    fn name(&self) -> &str {
303        "AVG"
304    }
305
306    fn description(&self) -> &str {
307        "Calculate the average of values"
308    }
309
310    fn create_state(&self) -> Box<dyn AggregateState> {
311        Box::new(AvgState {
312            sum: SumState {
313                int_sum: None,
314                float_sum: None,
315                has_values: false,
316            },
317            count: 0,
318        })
319    }
320}
321
322struct AvgState {
323    sum: SumState,
324    count: i64,
325}
326
327impl AggregateState for AvgState {
328    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
329        if !matches!(value, DataValue::Null) {
330            self.sum.accumulate(value)?;
331            self.count += 1;
332        }
333        Ok(())
334    }
335
336    fn finalize(self: Box<Self>) -> DataValue {
337        if self.count == 0 {
338            return DataValue::Null;
339        }
340
341        let sum = Box::new(self.sum).finalize();
342        match sum {
343            DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
344            DataValue::Float(f) => DataValue::Float(f / self.count as f64),
345            _ => DataValue::Null,
346        }
347    }
348
349    fn clone_box(&self) -> Box<dyn AggregateState> {
350        Box::new(AvgState {
351            sum: SumState {
352                int_sum: self.sum.int_sum,
353                float_sum: self.sum.float_sum,
354                has_values: self.sum.has_values,
355            },
356            count: self.count,
357        })
358    }
359
360    fn reset(&mut self) {
361        self.sum.reset();
362        self.count = 0;
363    }
364}
365
366// ============= MIN Implementation =============
367
368struct MinFunction;
369
370impl AggregateFunction for MinFunction {
371    fn name(&self) -> &str {
372        "MIN"
373    }
374
375    fn description(&self) -> &str {
376        "Find the minimum value"
377    }
378
379    fn create_state(&self) -> Box<dyn AggregateState> {
380        Box::new(MinMaxState {
381            is_min: true,
382            current: None,
383        })
384    }
385}
386
387// ============= MAX Implementation =============
388
389struct MaxFunction;
390
391impl AggregateFunction for MaxFunction {
392    fn name(&self) -> &str {
393        "MAX"
394    }
395
396    fn description(&self) -> &str {
397        "Find the maximum value"
398    }
399
400    fn create_state(&self) -> Box<dyn AggregateState> {
401        Box::new(MinMaxState {
402            is_min: false,
403            current: None,
404        })
405    }
406}
407
408struct MinMaxState {
409    is_min: bool,
410    current: Option<DataValue>,
411}
412
413impl AggregateState for MinMaxState {
414    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
415        if matches!(value, DataValue::Null) {
416            return Ok(());
417        }
418
419        match &self.current {
420            None => {
421                self.current = Some(value.clone());
422            }
423            Some(current) => {
424                let should_update = if self.is_min {
425                    value < current
426                } else {
427                    value > current
428                };
429
430                if should_update {
431                    self.current = Some(value.clone());
432                }
433            }
434        }
435
436        Ok(())
437    }
438
439    fn finalize(self: Box<Self>) -> DataValue {
440        self.current.unwrap_or(DataValue::Null)
441    }
442
443    fn clone_box(&self) -> Box<dyn AggregateState> {
444        Box::new(MinMaxState {
445            is_min: self.is_min,
446            current: self.current.clone(),
447        })
448    }
449
450    fn reset(&mut self) {
451        self.current = None;
452    }
453}
454
455// ============= STRING_AGG Implementation =============
456
457struct StringAggFunction {
458    separator: String,
459}
460
461impl StringAggFunction {
462    fn new() -> Self {
463        Self {
464            separator: ",".to_string(), // Default separator
465        }
466    }
467
468    fn with_separator(separator: String) -> Self {
469        Self { separator }
470    }
471}
472
473impl AggregateFunction for StringAggFunction {
474    fn name(&self) -> &str {
475        "STRING_AGG"
476    }
477
478    fn description(&self) -> &str {
479        "Concatenate strings with a separator"
480    }
481
482    fn create_state(&self) -> Box<dyn AggregateState> {
483        Box::new(StringAggState {
484            values: Vec::new(),
485            separator: self.separator.clone(),
486        })
487    }
488
489    fn set_parameters(&self, params: &[DataValue]) -> Result<Box<dyn AggregateFunction>> {
490        // STRING_AGG takes a separator as second parameter
491        if params.is_empty() {
492            return Ok(Box::new(StringAggFunction::new()));
493        }
494
495        let separator = match &params[0] {
496            DataValue::String(s) => s.clone(),
497            DataValue::InternedString(s) => s.to_string(),
498            _ => return Err(anyhow!("STRING_AGG separator must be a string")),
499        };
500
501        Ok(Box::new(StringAggFunction::with_separator(separator)))
502    }
503}
504
505struct StringAggState {
506    values: Vec<String>,
507    separator: String,
508}
509
510impl AggregateState for StringAggState {
511    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
512        match value {
513            DataValue::Null => Ok(()), // Skip nulls
514            DataValue::String(s) => {
515                self.values.push(s.clone());
516                Ok(())
517            }
518            DataValue::InternedString(s) => {
519                self.values.push(s.to_string());
520                Ok(())
521            }
522            DataValue::Integer(n) => {
523                self.values.push(n.to_string());
524                Ok(())
525            }
526            DataValue::Float(f) => {
527                self.values.push(f.to_string());
528                Ok(())
529            }
530            DataValue::Boolean(b) => {
531                self.values.push(b.to_string());
532                Ok(())
533            }
534            DataValue::DateTime(dt) => {
535                self.values.push(dt.to_string());
536                Ok(())
537            }
538            DataValue::Vector(v) => {
539                let components: Vec<String> = v.iter().map(|f| f.to_string()).collect();
540                self.values.push(format!("[{}]", components.join(",")));
541                Ok(())
542            }
543        }
544    }
545
546    fn finalize(self: Box<Self>) -> DataValue {
547        if self.values.is_empty() {
548            DataValue::Null
549        } else {
550            DataValue::String(self.values.join(&self.separator))
551        }
552    }
553
554    fn clone_box(&self) -> Box<dyn AggregateState> {
555        Box::new(StringAggState {
556            values: self.values.clone(),
557            separator: self.separator.clone(),
558        })
559    }
560
561    fn reset(&mut self) {
562        self.values.clear();
563    }
564}
565
566// ============= MEDIAN Implementation =============
567
568struct MedianFunction;
569
570impl AggregateFunction for MedianFunction {
571    fn name(&self) -> &str {
572        "MEDIAN"
573    }
574
575    fn description(&self) -> &str {
576        "Calculate the median (middle value) of numeric values"
577    }
578
579    fn create_state(&self) -> Box<dyn AggregateState> {
580        Box::new(CollectorState {
581            values: Vec::new(),
582            function_type: CollectorFunction::Median,
583        })
584    }
585}
586
587// ============= MODE Implementation =============
588
589struct ModeFunction;
590
591impl AggregateFunction for ModeFunction {
592    fn name(&self) -> &str {
593        "MODE"
594    }
595
596    fn description(&self) -> &str {
597        "Find the most frequently occurring value"
598    }
599
600    fn create_state(&self) -> Box<dyn AggregateState> {
601        Box::new(CollectorState {
602            values: Vec::new(),
603            function_type: CollectorFunction::Mode,
604        })
605    }
606}
607
608// ============= STDDEV Implementation =============
609
610struct StdDevFunction;
611
612impl AggregateFunction for StdDevFunction {
613    fn name(&self) -> &str {
614        "STDDEV"
615    }
616
617    fn description(&self) -> &str {
618        "Calculate the sample standard deviation"
619    }
620
621    fn create_state(&self) -> Box<dyn AggregateState> {
622        Box::new(CollectorState {
623            values: Vec::new(),
624            function_type: CollectorFunction::StdDev,
625        })
626    }
627}
628
629// ============= STDDEV_POP Implementation =============
630
631struct StdDevPFunction;
632
633impl AggregateFunction for StdDevPFunction {
634    fn name(&self) -> &str {
635        "STDDEV_POP"
636    }
637
638    fn description(&self) -> &str {
639        "Calculate the population standard deviation"
640    }
641
642    fn create_state(&self) -> Box<dyn AggregateState> {
643        Box::new(CollectorState {
644            values: Vec::new(),
645            function_type: CollectorFunction::StdDevP,
646        })
647    }
648}
649
650// ============= VARIANCE Implementation =============
651
652struct VarianceFunction;
653
654impl AggregateFunction for VarianceFunction {
655    fn name(&self) -> &str {
656        "VARIANCE"
657    }
658
659    fn description(&self) -> &str {
660        "Calculate the sample variance"
661    }
662
663    fn create_state(&self) -> Box<dyn AggregateState> {
664        Box::new(CollectorState {
665            values: Vec::new(),
666            function_type: CollectorFunction::Variance,
667        })
668    }
669}
670
671// ============= VARIANCE_POP Implementation =============
672
673struct VariancePFunction;
674
675impl AggregateFunction for VariancePFunction {
676    fn name(&self) -> &str {
677        "VARIANCE_POP"
678    }
679
680    fn description(&self) -> &str {
681        "Calculate the population variance"
682    }
683
684    fn create_state(&self) -> Box<dyn AggregateState> {
685        Box::new(CollectorState {
686            values: Vec::new(),
687            function_type: CollectorFunction::VarianceP,
688        })
689    }
690}
691
692// ============= PERCENTILE Implementation =============
693
694struct PercentileFunction;
695
696impl AggregateFunction for PercentileFunction {
697    fn name(&self) -> &str {
698        "PERCENTILE"
699    }
700
701    fn description(&self) -> &str {
702        "Calculate the nth percentile of values"
703    }
704
705    fn create_state(&self) -> Box<dyn AggregateState> {
706        Box::new(PercentileState {
707            values: Vec::new(),
708            percentile: 50.0, // Default to median
709        })
710    }
711
712    fn set_parameters(&self, params: &[DataValue]) -> Result<Box<dyn AggregateFunction>> {
713        // PERCENTILE takes the percentile value as a parameter
714        if params.is_empty() {
715            return Ok(Box::new(PercentileFunction));
716        }
717
718        let percentile = match &params[0] {
719            DataValue::Integer(i) => *i as f64,
720            DataValue::Float(f) => *f,
721            _ => {
722                return Err(anyhow!(
723                    "PERCENTILE parameter must be a number between 0 and 100"
724                ))
725            }
726        };
727
728        if percentile < 0.0 || percentile > 100.0 {
729            return Err(anyhow!("PERCENTILE must be between 0 and 100"));
730        }
731
732        Ok(Box::new(PercentileWithParam { percentile }))
733    }
734}
735
736struct PercentileWithParam {
737    percentile: f64,
738}
739
740impl AggregateFunction for PercentileWithParam {
741    fn name(&self) -> &str {
742        "PERCENTILE"
743    }
744
745    fn description(&self) -> &str {
746        "Calculate the nth percentile of values"
747    }
748
749    fn create_state(&self) -> Box<dyn AggregateState> {
750        Box::new(PercentileState {
751            values: Vec::new(),
752            percentile: self.percentile,
753        })
754    }
755}
756
757// ============= Collector State for functions that need all values =============
758
/// Which statistic a `CollectorState` computes when finalized.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CollectorFunction {
    Median,
    Mode,
    StdDev,    // Sample standard deviation
    StdDevP,   // Population standard deviation
    Variance,  // Sample variance
    VarianceP, // Population variance
}
767
768struct CollectorState {
769    values: Vec<f64>,
770    function_type: CollectorFunction,
771}
772
773impl AggregateState for CollectorState {
774    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
775        match value {
776            DataValue::Null => Ok(()), // Skip nulls
777            DataValue::Integer(n) => {
778                self.values.push(*n as f64);
779                Ok(())
780            }
781            DataValue::Float(f) => {
782                self.values.push(*f);
783                Ok(())
784            }
785            _ => match self.function_type {
786                CollectorFunction::Mode => {
787                    // Mode can work with non-numeric types, but we'll handle that separately
788                    Err(anyhow!("MODE currently only supports numeric values"))
789                }
790                _ => Err(anyhow!("Statistical functions require numeric values")),
791            },
792        }
793    }
794
795    fn finalize(self: Box<Self>) -> DataValue {
796        if self.values.is_empty() {
797            return DataValue::Null;
798        }
799
800        match self.function_type {
801            CollectorFunction::Median => {
802                let mut sorted = self.values.clone();
803                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
804                let len = sorted.len();
805                if len % 2 == 0 {
806                    DataValue::Float((sorted[len / 2 - 1] + sorted[len / 2]) / 2.0)
807                } else {
808                    DataValue::Float(sorted[len / 2])
809                }
810            }
811            CollectorFunction::Mode => {
812                use std::collections::HashMap;
813                let mut counts = HashMap::new();
814                for value in &self.values {
815                    *counts.entry(value.to_bits()).or_insert(0) += 1;
816                }
817                if let Some((bits, _)) = counts.iter().max_by_key(|&(_, count)| count) {
818                    DataValue::Float(f64::from_bits(*bits))
819                } else {
820                    DataValue::Null
821                }
822            }
823            CollectorFunction::StdDev | CollectorFunction::Variance => {
824                // Sample standard deviation and variance
825                if self.values.len() < 2 {
826                    return DataValue::Null;
827                }
828                let mean = self.values.iter().sum::<f64>() / self.values.len() as f64;
829                let variance = self.values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
830                    / (self.values.len() - 1) as f64; // N-1 for sample
831
832                match self.function_type {
833                    CollectorFunction::StdDev => DataValue::Float(variance.sqrt()),
834                    CollectorFunction::Variance => DataValue::Float(variance),
835                    _ => unreachable!(),
836                }
837            }
838            CollectorFunction::StdDevP | CollectorFunction::VarianceP => {
839                // Population standard deviation and variance
840                let mean = self.values.iter().sum::<f64>() / self.values.len() as f64;
841                let variance = self.values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
842                    / self.values.len() as f64; // N for population
843
844                match self.function_type {
845                    CollectorFunction::StdDevP => DataValue::Float(variance.sqrt()),
846                    CollectorFunction::VarianceP => DataValue::Float(variance),
847                    _ => unreachable!(),
848                }
849            }
850        }
851    }
852
853    fn clone_box(&self) -> Box<dyn AggregateState> {
854        Box::new(CollectorState {
855            values: self.values.clone(),
856            function_type: match self.function_type {
857                CollectorFunction::Median => CollectorFunction::Median,
858                CollectorFunction::Mode => CollectorFunction::Mode,
859                CollectorFunction::StdDev => CollectorFunction::StdDev,
860                CollectorFunction::StdDevP => CollectorFunction::StdDevP,
861                CollectorFunction::Variance => CollectorFunction::Variance,
862                CollectorFunction::VarianceP => CollectorFunction::VarianceP,
863            },
864        })
865    }
866
867    fn reset(&mut self) {
868        self.values.clear();
869    }
870}
871
872// ============= Percentile State =============
873
874struct PercentileState {
875    values: Vec<f64>,
876    percentile: f64,
877}
878
879impl AggregateState for PercentileState {
880    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
881        match value {
882            DataValue::Null => Ok(()), // Skip nulls
883            DataValue::Integer(n) => {
884                self.values.push(*n as f64);
885                Ok(())
886            }
887            DataValue::Float(f) => {
888                self.values.push(*f);
889                Ok(())
890            }
891            _ => Err(anyhow!("PERCENTILE requires numeric values")),
892        }
893    }
894
895    fn finalize(self: Box<Self>) -> DataValue {
896        if self.values.is_empty() {
897            return DataValue::Null;
898        }
899
900        let mut sorted = self.values.clone();
901        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
902
903        // Calculate the position in the sorted array
904        let position = (self.percentile / 100.0) * (sorted.len() - 1) as f64;
905        let lower = position.floor() as usize;
906        let upper = position.ceil() as usize;
907
908        if lower == upper {
909            DataValue::Float(sorted[lower])
910        } else {
911            // Linear interpolation between two values
912            let weight = position - lower as f64;
913            DataValue::Float(sorted[lower] * (1.0 - weight) + sorted[upper] * weight)
914        }
915    }
916
917    fn clone_box(&self) -> Box<dyn AggregateState> {
918        Box::new(PercentileState {
919            values: self.values.clone(),
920            percentile: self.percentile,
921        })
922    }
923
924    fn reset(&mut self) {
925        self.values.clear();
926    }
927}
928
#[cfg(test)]
mod tests {
    use super::*;

    /// Feed `inputs` through a fresh state of `func` and return the result.
    fn run(func: &dyn AggregateFunction, inputs: &[DataValue]) -> DataValue {
        let mut state = func.create_state();
        for value in inputs {
            state.accumulate(value).unwrap();
        }
        state.finalize()
    }

    /// Build `DataValue::Integer` inputs.
    fn ints(values: &[i64]) -> Vec<DataValue> {
        values.iter().map(|&n| DataValue::Integer(n)).collect()
    }

    #[test]
    fn test_registry_creation() {
        let registry = AggregateFunctionRegistry::new();
        for name in &[
            "COUNT",
            "COUNT_STAR",
            "SUM",
            "AVG",
            "MIN",
            "MAX",
            "STRING_AGG",
            "MEDIAN",
            "MODE",
            "STDDEV",
            "STDDEV_POP",
            "VARIANCE",
            "VARIANCE_POP",
            "PERCENTILE",
        ] {
            assert!(registry.contains(name), "missing builtin {}", name);
        }
        assert_eq!(registry.list_functions().len(), 14);
    }

    #[test]
    fn test_registry_lookup_is_case_insensitive() {
        let registry = AggregateFunctionRegistry::new();
        assert!(registry.contains("sum"));
        let func = registry.get("string_agg").expect("STRING_AGG is built in");
        assert_eq!(func.name(), "STRING_AGG");
        assert!(registry.get("NO_SUCH_AGGREGATE").is_none());
    }

    #[test]
    fn test_count_aggregate() {
        let func = CountFunction;
        let mut state = func.create_state();

        state.accumulate(&DataValue::Integer(1)).unwrap();
        state.accumulate(&DataValue::Null).unwrap();
        state.accumulate(&DataValue::Integer(3)).unwrap();

        let result = state.finalize();
        assert_eq!(result, DataValue::Integer(2));
    }

    #[test]
    fn test_count_star_includes_nulls() {
        let inputs = vec![DataValue::Null, DataValue::Integer(7), DataValue::Null];
        assert_eq!(run(&CountStarFunction, &inputs), DataValue::Integer(3));
    }

    #[test]
    fn test_sum() {
        assert_eq!(run(&SumFunction, &ints(&[1, 2, 3])), DataValue::Integer(6));

        let mixed = vec![DataValue::Integer(1), DataValue::Null, DataValue::Float(2.5)];
        assert_eq!(run(&SumFunction, &mixed), DataValue::Float(3.5));

        assert_eq!(run(&SumFunction, &[DataValue::Null]), DataValue::Null);

        let mut state = SumFunction.create_state();
        assert!(state.accumulate(&DataValue::Boolean(true)).is_err());
    }

    #[test]
    fn test_avg() {
        let mut inputs = ints(&[1, 2, 3, 4]);
        inputs.push(DataValue::Null);
        assert_eq!(run(&AvgFunction, &inputs), DataValue::Float(2.5));
        assert_eq!(run(&AvgFunction, &[]), DataValue::Null);
    }

    #[test]
    fn test_min_max() {
        let mut inputs = ints(&[3, 1, 2]);
        inputs.insert(1, DataValue::Null);
        assert_eq!(run(&MinFunction, &inputs), DataValue::Integer(1));
        assert_eq!(run(&MaxFunction, &inputs), DataValue::Integer(3));
        assert_eq!(run(&MinFunction, &[DataValue::Null]), DataValue::Null);
    }

    #[test]
    fn test_string_agg() {
        let func = StringAggFunction::with_separator(", ".to_string());
        let mut state = func.create_state();

        state
            .accumulate(&DataValue::String("apple".to_string()))
            .unwrap();
        state
            .accumulate(&DataValue::String("banana".to_string()))
            .unwrap();
        state
            .accumulate(&DataValue::String("cherry".to_string()))
            .unwrap();

        let result = state.finalize();
        assert_eq!(
            result,
            DataValue::String("apple, banana, cherry".to_string())
        );
    }

    #[test]
    fn test_string_agg_default_separator_and_nulls() {
        let func = StringAggFunction::new().set_parameters(&[]).unwrap();
        let inputs = vec![
            DataValue::String("a".to_string()),
            DataValue::Null,
            DataValue::Integer(2),
        ];
        assert_eq!(run(&*func, &inputs), DataValue::String("a,2".to_string()));
        assert_eq!(run(&*func, &[DataValue::Null]), DataValue::Null);
        assert!(StringAggFunction::new()
            .set_parameters(&[DataValue::Integer(1)])
            .is_err());
    }

    #[test]
    fn test_median_and_mode() {
        assert_eq!(run(&MedianFunction, &ints(&[3, 1, 4, 2])), DataValue::Float(2.5));
        assert_eq!(run(&MedianFunction, &ints(&[5, 1, 3])), DataValue::Float(3.0));
        assert_eq!(run(&ModeFunction, &ints(&[1, 2, 2, 3])), DataValue::Float(2.0));
        assert_eq!(run(&MedianFunction, &[]), DataValue::Null);
    }

    #[test]
    fn test_variance_and_stddev() {
        let data = ints(&[2, 4, 4, 4, 5, 5, 7, 9]);
        assert_eq!(run(&VariancePFunction, &data), DataValue::Float(4.0));
        assert_eq!(run(&StdDevPFunction, &data), DataValue::Float(2.0));
        assert_eq!(run(&VarianceFunction, &ints(&[1, 3])), DataValue::Float(2.0));
        // Sample statistics are undefined for a single observation.
        assert_eq!(run(&StdDevFunction, &ints(&[42])), DataValue::Null);
    }

    #[test]
    fn test_percentile() {
        let data = ints(&[1, 2, 3, 4]);
        // No parameter: defaults to the 50th percentile (interpolated).
        assert_eq!(run(&PercentileFunction, &data), DataValue::Float(2.5));

        let p0 = PercentileFunction
            .set_parameters(&[DataValue::Integer(0)])
            .unwrap();
        assert_eq!(run(&*p0, &data), DataValue::Float(1.0));

        let p100 = PercentileFunction
            .set_parameters(&[DataValue::Float(100.0)])
            .unwrap();
        assert_eq!(run(&*p100, &data), DataValue::Float(4.0));

        assert!(PercentileFunction
            .set_parameters(&[DataValue::Integer(150)])
            .is_err());
        assert!(PercentileFunction
            .set_parameters(&[DataValue::String("x".to_string())])
            .is_err());
    }
}