sql_cli/sql/aggregate_functions/
mod.rs

1// Aggregate Function Registry
2// Provides a clean API for group-based aggregate computations
3// Moves all aggregate logic out of the evaluator into a registry pattern
4
5use anyhow::{anyhow, Result};
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use crate::data::datatable::DataValue;
10
/// State maintained during aggregation.
///
/// Each aggregate function manages its own state type. Values are fed in one at
/// a time through [`AggregateState::accumulate`], and the result is produced by
/// [`AggregateState::finalize`], which consumes the boxed state.
pub trait AggregateState: Send + Sync {
    /// Add a value to the aggregate.
    ///
    /// # Errors
    /// Implementations return an error when `value` cannot be folded into the
    /// state (for example a non-numeric value handed to a numeric aggregate).
    fn accumulate(&mut self, value: &DataValue) -> Result<()>;

    /// Finalize and return the aggregate result, consuming the state.
    fn finalize(self: Box<Self>) -> DataValue;

    /// Create a boxed copy of this state, including what has been accumulated so far.
    fn clone_box(&self) -> Box<dyn AggregateState>;

    /// Reset the state for reuse, discarding everything accumulated so far.
    fn reset(&mut self);
}
26
/// Aggregate function trait.
///
/// Each aggregate function (SUM, COUNT, AVG, etc.) implements this. The function
/// object itself holds no per-group data; it acts as a factory for
/// [`AggregateState`] values via [`AggregateFunction::create_state`].
pub trait AggregateFunction: Send + Sync {
    /// Function name (e.g., "SUM", "COUNT", "STRING_AGG")
    fn name(&self) -> &str;

    /// Description for help system
    fn description(&self) -> &str;

    /// Create initial (empty) state for this aggregate
    fn create_state(&self) -> Box<dyn AggregateState>;

    /// Does this aggregate support DISTINCT?
    ///
    /// Defaults to `true`; implementors override this to opt out.
    fn supports_distinct(&self) -> bool {
        true // Most aggregates should support DISTINCT
    }

    /// For aggregates with parameters (like STRING_AGG separator).
    ///
    /// Parameterized aggregates override this to return a new function object
    /// configured from `_params`.
    ///
    /// NOTE(review): the default implementation returns a `DummyClone` that only
    /// carries this function's name; calling `create_state` on that value panics.
    /// Presumably callers only invoke `set_parameters` on aggregates that
    /// override it - confirm against the call sites.
    fn set_parameters(&self, _params: &[DataValue]) -> Result<Box<dyn AggregateFunction>> {
        // Default implementation for aggregates without parameters
        Ok(Box::new(DummyClone(self.name().to_string())))
    }
}
50
51// Dummy clone helper for default implementation
52struct DummyClone(String);
53impl AggregateFunction for DummyClone {
54    fn name(&self) -> &str {
55        &self.0
56    }
57    fn description(&self) -> &str {
58        ""
59    }
60    fn create_state(&self) -> Box<dyn AggregateState> {
61        panic!("DummyClone should not be used")
62    }
63}
64
65/// Registry for aggregate functions
66pub struct AggregateFunctionRegistry {
67    functions: HashMap<String, Arc<Box<dyn AggregateFunction>>>,
68}
69
70impl AggregateFunctionRegistry {
71    pub fn new() -> Self {
72        let mut registry = Self {
73            functions: HashMap::new(),
74        };
75        registry.register_builtin_functions();
76        registry
77    }
78
79    /// Register an aggregate function
80    pub fn register(&mut self, function: Box<dyn AggregateFunction>) {
81        let name = function.name().to_uppercase();
82        self.functions.insert(name, Arc::new(function));
83    }
84
85    /// Get an aggregate function by name
86    pub fn get(&self, name: &str) -> Option<Arc<Box<dyn AggregateFunction>>> {
87        self.functions.get(&name.to_uppercase()).cloned()
88    }
89
90    /// Check if a function exists
91    pub fn contains(&self, name: &str) -> bool {
92        self.functions.contains_key(&name.to_uppercase())
93    }
94
95    /// List all registered functions
96    pub fn list_functions(&self) -> Vec<String> {
97        self.functions.keys().cloned().collect()
98    }
99
100    /// Register built-in aggregate functions
101    fn register_builtin_functions(&mut self) {
102        // Basic aggregates
103        self.register(Box::new(CountFunction));
104        self.register(Box::new(CountStarFunction));
105        self.register(Box::new(SumFunction));
106        self.register(Box::new(AvgFunction));
107        self.register(Box::new(MinFunction));
108        self.register(Box::new(MaxFunction));
109
110        // String aggregates
111        self.register(Box::new(StringAggFunction::new()));
112
113        // Statistical aggregates
114        self.register(Box::new(MedianFunction));
115        self.register(Box::new(ModeFunction));
116        self.register(Box::new(StdDevFunction));
117        self.register(Box::new(StdDevPFunction));
118        self.register(Box::new(VarianceFunction));
119        self.register(Box::new(VariancePFunction));
120        self.register(Box::new(PercentileFunction));
121    }
122}
123
124// ============= COUNT Implementation =============
125
126struct CountFunction;
127
128impl AggregateFunction for CountFunction {
129    fn name(&self) -> &str {
130        "COUNT"
131    }
132
133    fn description(&self) -> &str {
134        "Count the number of non-null values"
135    }
136
137    fn create_state(&self) -> Box<dyn AggregateState> {
138        Box::new(CountState { count: 0 })
139    }
140}
141
142struct CountState {
143    count: i64,
144}
145
146impl AggregateState for CountState {
147    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
148        // COUNT(column) counts non-nulls only
149        if !matches!(value, DataValue::Null) {
150            self.count += 1;
151        }
152        Ok(())
153    }
154
155    fn finalize(self: Box<Self>) -> DataValue {
156        DataValue::Integer(self.count)
157    }
158
159    fn clone_box(&self) -> Box<dyn AggregateState> {
160        Box::new(CountState { count: self.count })
161    }
162
163    fn reset(&mut self) {
164        self.count = 0;
165    }
166}
167
168// COUNT(*) - counts all rows including nulls
169struct CountStarFunction;
170
171impl AggregateFunction for CountStarFunction {
172    fn name(&self) -> &str {
173        "COUNT_STAR"
174    }
175
176    fn description(&self) -> &str {
177        "Count all rows including nulls"
178    }
179
180    fn create_state(&self) -> Box<dyn AggregateState> {
181        Box::new(CountStarState { count: 0 })
182    }
183}
184
185struct CountStarState {
186    count: i64,
187}
188
189impl AggregateState for CountStarState {
190    fn accumulate(&mut self, _value: &DataValue) -> Result<()> {
191        // COUNT(*) counts all rows, even nulls
192        self.count += 1;
193        Ok(())
194    }
195
196    fn finalize(self: Box<Self>) -> DataValue {
197        DataValue::Integer(self.count)
198    }
199
200    fn clone_box(&self) -> Box<dyn AggregateState> {
201        Box::new(CountStarState { count: self.count })
202    }
203
204    fn reset(&mut self) {
205        self.count = 0;
206    }
207}
208
209// ============= SUM Implementation =============
210
211struct SumFunction;
212
213impl AggregateFunction for SumFunction {
214    fn name(&self) -> &str {
215        "SUM"
216    }
217
218    fn description(&self) -> &str {
219        "Calculate the sum of values"
220    }
221
222    fn create_state(&self) -> Box<dyn AggregateState> {
223        Box::new(SumState {
224            int_sum: None,
225            float_sum: None,
226            has_values: false,
227        })
228    }
229}
230
231struct SumState {
232    int_sum: Option<i64>,
233    float_sum: Option<f64>,
234    has_values: bool,
235}
236
237impl AggregateState for SumState {
238    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
239        match value {
240            DataValue::Null => Ok(()), // Skip nulls
241            DataValue::Integer(n) => {
242                self.has_values = true;
243                if let Some(ref mut sum) = self.int_sum {
244                    *sum = sum.saturating_add(*n);
245                } else if let Some(ref mut fsum) = self.float_sum {
246                    *fsum += *n as f64;
247                } else {
248                    self.int_sum = Some(*n);
249                }
250                Ok(())
251            }
252            DataValue::Float(f) => {
253                self.has_values = true;
254                // Once we have a float, convert everything to float
255                if let Some(isum) = self.int_sum.take() {
256                    self.float_sum = Some(isum as f64 + f);
257                } else if let Some(ref mut fsum) = self.float_sum {
258                    *fsum += f;
259                } else {
260                    self.float_sum = Some(*f);
261                }
262                Ok(())
263            }
264            _ => Err(anyhow!("Cannot sum non-numeric value")),
265        }
266    }
267
268    fn finalize(self: Box<Self>) -> DataValue {
269        if !self.has_values {
270            return DataValue::Null;
271        }
272
273        if let Some(fsum) = self.float_sum {
274            DataValue::Float(fsum)
275        } else if let Some(isum) = self.int_sum {
276            DataValue::Integer(isum)
277        } else {
278            DataValue::Null
279        }
280    }
281
282    fn clone_box(&self) -> Box<dyn AggregateState> {
283        Box::new(SumState {
284            int_sum: self.int_sum,
285            float_sum: self.float_sum,
286            has_values: self.has_values,
287        })
288    }
289
290    fn reset(&mut self) {
291        self.int_sum = None;
292        self.float_sum = None;
293        self.has_values = false;
294    }
295}
296
297// ============= AVG Implementation =============
298
299struct AvgFunction;
300
301impl AggregateFunction for AvgFunction {
302    fn name(&self) -> &str {
303        "AVG"
304    }
305
306    fn description(&self) -> &str {
307        "Calculate the average of values"
308    }
309
310    fn create_state(&self) -> Box<dyn AggregateState> {
311        Box::new(AvgState {
312            sum: SumState {
313                int_sum: None,
314                float_sum: None,
315                has_values: false,
316            },
317            count: 0,
318        })
319    }
320}
321
322struct AvgState {
323    sum: SumState,
324    count: i64,
325}
326
327impl AggregateState for AvgState {
328    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
329        if !matches!(value, DataValue::Null) {
330            self.sum.accumulate(value)?;
331            self.count += 1;
332        }
333        Ok(())
334    }
335
336    fn finalize(self: Box<Self>) -> DataValue {
337        if self.count == 0 {
338            return DataValue::Null;
339        }
340
341        let sum = Box::new(self.sum).finalize();
342        match sum {
343            DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
344            DataValue::Float(f) => DataValue::Float(f / self.count as f64),
345            _ => DataValue::Null,
346        }
347    }
348
349    fn clone_box(&self) -> Box<dyn AggregateState> {
350        Box::new(AvgState {
351            sum: SumState {
352                int_sum: self.sum.int_sum,
353                float_sum: self.sum.float_sum,
354                has_values: self.sum.has_values,
355            },
356            count: self.count,
357        })
358    }
359
360    fn reset(&mut self) {
361        self.sum.reset();
362        self.count = 0;
363    }
364}
365
366// ============= MIN Implementation =============
367
368struct MinFunction;
369
370impl AggregateFunction for MinFunction {
371    fn name(&self) -> &str {
372        "MIN"
373    }
374
375    fn description(&self) -> &str {
376        "Find the minimum value"
377    }
378
379    fn create_state(&self) -> Box<dyn AggregateState> {
380        Box::new(MinMaxState {
381            is_min: true,
382            current: None,
383        })
384    }
385}
386
387// ============= MAX Implementation =============
388
389struct MaxFunction;
390
391impl AggregateFunction for MaxFunction {
392    fn name(&self) -> &str {
393        "MAX"
394    }
395
396    fn description(&self) -> &str {
397        "Find the maximum value"
398    }
399
400    fn create_state(&self) -> Box<dyn AggregateState> {
401        Box::new(MinMaxState {
402            is_min: false,
403            current: None,
404        })
405    }
406}
407
408struct MinMaxState {
409    is_min: bool,
410    current: Option<DataValue>,
411}
412
413impl AggregateState for MinMaxState {
414    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
415        if matches!(value, DataValue::Null) {
416            return Ok(());
417        }
418
419        match &self.current {
420            None => {
421                self.current = Some(value.clone());
422            }
423            Some(current) => {
424                let should_update = if self.is_min {
425                    value < current
426                } else {
427                    value > current
428                };
429
430                if should_update {
431                    self.current = Some(value.clone());
432                }
433            }
434        }
435
436        Ok(())
437    }
438
439    fn finalize(self: Box<Self>) -> DataValue {
440        self.current.unwrap_or(DataValue::Null)
441    }
442
443    fn clone_box(&self) -> Box<dyn AggregateState> {
444        Box::new(MinMaxState {
445            is_min: self.is_min,
446            current: self.current.clone(),
447        })
448    }
449
450    fn reset(&mut self) {
451        self.current = None;
452    }
453}
454
455// ============= STRING_AGG Implementation =============
456
457struct StringAggFunction {
458    separator: String,
459}
460
461impl StringAggFunction {
462    fn new() -> Self {
463        Self {
464            separator: ",".to_string(), // Default separator
465        }
466    }
467
468    fn with_separator(separator: String) -> Self {
469        Self { separator }
470    }
471}
472
473impl AggregateFunction for StringAggFunction {
474    fn name(&self) -> &str {
475        "STRING_AGG"
476    }
477
478    fn description(&self) -> &str {
479        "Concatenate strings with a separator"
480    }
481
482    fn create_state(&self) -> Box<dyn AggregateState> {
483        Box::new(StringAggState {
484            values: Vec::new(),
485            separator: self.separator.clone(),
486        })
487    }
488
489    fn set_parameters(&self, params: &[DataValue]) -> Result<Box<dyn AggregateFunction>> {
490        // STRING_AGG takes a separator as second parameter
491        if params.is_empty() {
492            return Ok(Box::new(StringAggFunction::new()));
493        }
494
495        let separator = match &params[0] {
496            DataValue::String(s) => s.clone(),
497            DataValue::InternedString(s) => s.to_string(),
498            _ => return Err(anyhow!("STRING_AGG separator must be a string")),
499        };
500
501        Ok(Box::new(StringAggFunction::with_separator(separator)))
502    }
503}
504
505struct StringAggState {
506    values: Vec<String>,
507    separator: String,
508}
509
510impl AggregateState for StringAggState {
511    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
512        match value {
513            DataValue::Null => Ok(()), // Skip nulls
514            DataValue::String(s) => {
515                self.values.push(s.clone());
516                Ok(())
517            }
518            DataValue::InternedString(s) => {
519                self.values.push(s.to_string());
520                Ok(())
521            }
522            DataValue::Integer(n) => {
523                self.values.push(n.to_string());
524                Ok(())
525            }
526            DataValue::Float(f) => {
527                self.values.push(f.to_string());
528                Ok(())
529            }
530            DataValue::Boolean(b) => {
531                self.values.push(b.to_string());
532                Ok(())
533            }
534            DataValue::DateTime(dt) => {
535                self.values.push(dt.to_string());
536                Ok(())
537            }
538        }
539    }
540
541    fn finalize(self: Box<Self>) -> DataValue {
542        if self.values.is_empty() {
543            DataValue::Null
544        } else {
545            DataValue::String(self.values.join(&self.separator))
546        }
547    }
548
549    fn clone_box(&self) -> Box<dyn AggregateState> {
550        Box::new(StringAggState {
551            values: self.values.clone(),
552            separator: self.separator.clone(),
553        })
554    }
555
556    fn reset(&mut self) {
557        self.values.clear();
558    }
559}
560
561// ============= MEDIAN Implementation =============
562
563struct MedianFunction;
564
565impl AggregateFunction for MedianFunction {
566    fn name(&self) -> &str {
567        "MEDIAN"
568    }
569
570    fn description(&self) -> &str {
571        "Calculate the median (middle value) of numeric values"
572    }
573
574    fn create_state(&self) -> Box<dyn AggregateState> {
575        Box::new(CollectorState {
576            values: Vec::new(),
577            function_type: CollectorFunction::Median,
578        })
579    }
580}
581
582// ============= MODE Implementation =============
583
584struct ModeFunction;
585
586impl AggregateFunction for ModeFunction {
587    fn name(&self) -> &str {
588        "MODE"
589    }
590
591    fn description(&self) -> &str {
592        "Find the most frequently occurring value"
593    }
594
595    fn create_state(&self) -> Box<dyn AggregateState> {
596        Box::new(CollectorState {
597            values: Vec::new(),
598            function_type: CollectorFunction::Mode,
599        })
600    }
601}
602
603// ============= STDDEV Implementation =============
604
605struct StdDevFunction;
606
607impl AggregateFunction for StdDevFunction {
608    fn name(&self) -> &str {
609        "STDDEV"
610    }
611
612    fn description(&self) -> &str {
613        "Calculate the sample standard deviation"
614    }
615
616    fn create_state(&self) -> Box<dyn AggregateState> {
617        Box::new(CollectorState {
618            values: Vec::new(),
619            function_type: CollectorFunction::StdDev,
620        })
621    }
622}
623
624// ============= STDDEV_POP Implementation =============
625
626struct StdDevPFunction;
627
628impl AggregateFunction for StdDevPFunction {
629    fn name(&self) -> &str {
630        "STDDEV_POP"
631    }
632
633    fn description(&self) -> &str {
634        "Calculate the population standard deviation"
635    }
636
637    fn create_state(&self) -> Box<dyn AggregateState> {
638        Box::new(CollectorState {
639            values: Vec::new(),
640            function_type: CollectorFunction::StdDevP,
641        })
642    }
643}
644
645// ============= VARIANCE Implementation =============
646
647struct VarianceFunction;
648
649impl AggregateFunction for VarianceFunction {
650    fn name(&self) -> &str {
651        "VARIANCE"
652    }
653
654    fn description(&self) -> &str {
655        "Calculate the sample variance"
656    }
657
658    fn create_state(&self) -> Box<dyn AggregateState> {
659        Box::new(CollectorState {
660            values: Vec::new(),
661            function_type: CollectorFunction::Variance,
662        })
663    }
664}
665
666// ============= VARIANCE_POP Implementation =============
667
668struct VariancePFunction;
669
670impl AggregateFunction for VariancePFunction {
671    fn name(&self) -> &str {
672        "VARIANCE_POP"
673    }
674
675    fn description(&self) -> &str {
676        "Calculate the population variance"
677    }
678
679    fn create_state(&self) -> Box<dyn AggregateState> {
680        Box::new(CollectorState {
681            values: Vec::new(),
682            function_type: CollectorFunction::VarianceP,
683        })
684    }
685}
686
687// ============= PERCENTILE Implementation =============
688
689struct PercentileFunction;
690
691impl AggregateFunction for PercentileFunction {
692    fn name(&self) -> &str {
693        "PERCENTILE"
694    }
695
696    fn description(&self) -> &str {
697        "Calculate the nth percentile of values"
698    }
699
700    fn create_state(&self) -> Box<dyn AggregateState> {
701        Box::new(PercentileState {
702            values: Vec::new(),
703            percentile: 50.0, // Default to median
704        })
705    }
706
707    fn set_parameters(&self, params: &[DataValue]) -> Result<Box<dyn AggregateFunction>> {
708        // PERCENTILE takes the percentile value as a parameter
709        if params.is_empty() {
710            return Ok(Box::new(PercentileFunction));
711        }
712
713        let percentile = match &params[0] {
714            DataValue::Integer(i) => *i as f64,
715            DataValue::Float(f) => *f,
716            _ => {
717                return Err(anyhow!(
718                    "PERCENTILE parameter must be a number between 0 and 100"
719                ))
720            }
721        };
722
723        if percentile < 0.0 || percentile > 100.0 {
724            return Err(anyhow!("PERCENTILE must be between 0 and 100"));
725        }
726
727        Ok(Box::new(PercentileWithParam { percentile }))
728    }
729}
730
731struct PercentileWithParam {
732    percentile: f64,
733}
734
735impl AggregateFunction for PercentileWithParam {
736    fn name(&self) -> &str {
737        "PERCENTILE"
738    }
739
740    fn description(&self) -> &str {
741        "Calculate the nth percentile of values"
742    }
743
744    fn create_state(&self) -> Box<dyn AggregateState> {
745        Box::new(PercentileState {
746            values: Vec::new(),
747            percentile: self.percentile,
748        })
749    }
750}
751
752// ============= Collector State for functions that need all values =============
753
754enum CollectorFunction {
755    Median,
756    Mode,
757    StdDev,    // Sample standard deviation
758    StdDevP,   // Population standard deviation
759    Variance,  // Sample variance
760    VarianceP, // Population variance
761}
762
763struct CollectorState {
764    values: Vec<f64>,
765    function_type: CollectorFunction,
766}
767
768impl AggregateState for CollectorState {
769    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
770        match value {
771            DataValue::Null => Ok(()), // Skip nulls
772            DataValue::Integer(n) => {
773                self.values.push(*n as f64);
774                Ok(())
775            }
776            DataValue::Float(f) => {
777                self.values.push(*f);
778                Ok(())
779            }
780            _ => match self.function_type {
781                CollectorFunction::Mode => {
782                    // Mode can work with non-numeric types, but we'll handle that separately
783                    Err(anyhow!("MODE currently only supports numeric values"))
784                }
785                _ => Err(anyhow!("Statistical functions require numeric values")),
786            },
787        }
788    }
789
790    fn finalize(self: Box<Self>) -> DataValue {
791        if self.values.is_empty() {
792            return DataValue::Null;
793        }
794
795        match self.function_type {
796            CollectorFunction::Median => {
797                let mut sorted = self.values.clone();
798                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
799                let len = sorted.len();
800                if len % 2 == 0 {
801                    DataValue::Float((sorted[len / 2 - 1] + sorted[len / 2]) / 2.0)
802                } else {
803                    DataValue::Float(sorted[len / 2])
804                }
805            }
806            CollectorFunction::Mode => {
807                use std::collections::HashMap;
808                let mut counts = HashMap::new();
809                for value in &self.values {
810                    *counts.entry(value.to_bits()).or_insert(0) += 1;
811                }
812                if let Some((bits, _)) = counts.iter().max_by_key(|&(_, count)| count) {
813                    DataValue::Float(f64::from_bits(*bits))
814                } else {
815                    DataValue::Null
816                }
817            }
818            CollectorFunction::StdDev | CollectorFunction::Variance => {
819                // Sample standard deviation and variance
820                if self.values.len() < 2 {
821                    return DataValue::Null;
822                }
823                let mean = self.values.iter().sum::<f64>() / self.values.len() as f64;
824                let variance = self.values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
825                    / (self.values.len() - 1) as f64; // N-1 for sample
826
827                match self.function_type {
828                    CollectorFunction::StdDev => DataValue::Float(variance.sqrt()),
829                    CollectorFunction::Variance => DataValue::Float(variance),
830                    _ => unreachable!(),
831                }
832            }
833            CollectorFunction::StdDevP | CollectorFunction::VarianceP => {
834                // Population standard deviation and variance
835                let mean = self.values.iter().sum::<f64>() / self.values.len() as f64;
836                let variance = self.values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
837                    / self.values.len() as f64; // N for population
838
839                match self.function_type {
840                    CollectorFunction::StdDevP => DataValue::Float(variance.sqrt()),
841                    CollectorFunction::VarianceP => DataValue::Float(variance),
842                    _ => unreachable!(),
843                }
844            }
845        }
846    }
847
848    fn clone_box(&self) -> Box<dyn AggregateState> {
849        Box::new(CollectorState {
850            values: self.values.clone(),
851            function_type: match self.function_type {
852                CollectorFunction::Median => CollectorFunction::Median,
853                CollectorFunction::Mode => CollectorFunction::Mode,
854                CollectorFunction::StdDev => CollectorFunction::StdDev,
855                CollectorFunction::StdDevP => CollectorFunction::StdDevP,
856                CollectorFunction::Variance => CollectorFunction::Variance,
857                CollectorFunction::VarianceP => CollectorFunction::VarianceP,
858            },
859        })
860    }
861
862    fn reset(&mut self) {
863        self.values.clear();
864    }
865}
866
867// ============= Percentile State =============
868
869struct PercentileState {
870    values: Vec<f64>,
871    percentile: f64,
872}
873
874impl AggregateState for PercentileState {
875    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
876        match value {
877            DataValue::Null => Ok(()), // Skip nulls
878            DataValue::Integer(n) => {
879                self.values.push(*n as f64);
880                Ok(())
881            }
882            DataValue::Float(f) => {
883                self.values.push(*f);
884                Ok(())
885            }
886            _ => Err(anyhow!("PERCENTILE requires numeric values")),
887        }
888    }
889
890    fn finalize(self: Box<Self>) -> DataValue {
891        if self.values.is_empty() {
892            return DataValue::Null;
893        }
894
895        let mut sorted = self.values.clone();
896        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
897
898        // Calculate the position in the sorted array
899        let position = (self.percentile / 100.0) * (sorted.len() - 1) as f64;
900        let lower = position.floor() as usize;
901        let upper = position.ceil() as usize;
902
903        if lower == upper {
904            DataValue::Float(sorted[lower])
905        } else {
906            // Linear interpolation between two values
907            let weight = position - lower as f64;
908            DataValue::Float(sorted[lower] * (1.0 - weight) + sorted[upper] * weight)
909        }
910    }
911
912    fn clone_box(&self) -> Box<dyn AggregateState> {
913        Box::new(PercentileState {
914            values: self.values.clone(),
915            percentile: self.percentile,
916        })
917    }
918
919    fn reset(&mut self) {
920        self.values.clear();
921    }
922}
923
#[cfg(test)]
mod tests {
    use super::*;

    /// A new registry exposes the basic and string built-ins.
    #[test]
    fn test_registry_creation() {
        let registry = AggregateFunctionRegistry::new();
        let expected = ["COUNT", "SUM", "AVG", "MIN", "MAX", "STRING_AGG"];
        for name in expected.iter() {
            assert!(registry.contains(name));
        }
    }

    /// COUNT ignores the null between the two integers.
    #[test]
    fn test_count_aggregate() {
        let mut state = CountFunction.create_state();

        let inputs = [
            DataValue::Integer(1),
            DataValue::Null,
            DataValue::Integer(3),
        ];
        for input in inputs.iter() {
            state.accumulate(input).unwrap();
        }

        assert_eq!(state.finalize(), DataValue::Integer(2));
    }

    /// STRING_AGG joins its inputs in arrival order with the custom separator.
    #[test]
    fn test_string_agg() {
        let func = StringAggFunction::with_separator(", ".to_string());
        let mut state = func.create_state();

        for fruit in ["apple", "banana", "cherry"].iter() {
            let value = DataValue::String(fruit.to_string());
            state.accumulate(&value).unwrap();
        }

        let expected = DataValue::String("apple, banana, cherry".to_string());
        assert_eq!(state.finalize(), expected);
    }
}