sql_cli/sql/aggregate_functions/
mod.rs

// Aggregate Function Registry
// Provides a clean API for group-based aggregate computations
// Moves all aggregate logic out of the evaluator into a registry pattern

5use anyhow::{anyhow, Result};
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use crate::data::datatable::DataValue;
10
11/// State maintained during aggregation
12/// Each aggregate function manages its own state type
13pub trait AggregateState: Send + Sync {
14    /// Add a value to the aggregate
15    fn accumulate(&mut self, value: &DataValue) -> Result<()>;
16
17    /// Finalize and return the aggregate result
18    fn finalize(self: Box<Self>) -> DataValue;
19
20    /// Create a new instance of this state
21    fn clone_box(&self) -> Box<dyn AggregateState>;
22
23    /// Reset the state for reuse
24    fn reset(&mut self);
25}
26
27/// Aggregate function trait
28/// Each aggregate function (SUM, COUNT, AVG, etc.) implements this
29pub trait AggregateFunction: Send + Sync {
30    /// Function name (e.g., "SUM", "COUNT", "STRING_AGG")
31    fn name(&self) -> &str;
32
33    /// Description for help system
34    fn description(&self) -> &str;
35
36    /// Create initial state for this aggregate
37    fn create_state(&self) -> Box<dyn AggregateState>;
38
39    /// Does this aggregate support DISTINCT?
40    fn supports_distinct(&self) -> bool {
41        true // Most aggregates should support DISTINCT
42    }
43
44    /// For aggregates with parameters (like STRING_AGG separator)
45    fn set_parameters(&self, _params: &[DataValue]) -> Result<Box<dyn AggregateFunction>> {
46        // Default implementation for aggregates without parameters
47        Ok(Box::new(DummyClone(self.name().to_string())))
48    }
49}
50
51// Dummy clone helper for default implementation
52struct DummyClone(String);
53impl AggregateFunction for DummyClone {
54    fn name(&self) -> &str {
55        &self.0
56    }
57    fn description(&self) -> &str {
58        ""
59    }
60    fn create_state(&self) -> Box<dyn AggregateState> {
61        panic!("DummyClone should not be used")
62    }
63}
64
65/// Registry for aggregate functions
66pub struct AggregateFunctionRegistry {
67    functions: HashMap<String, Arc<Box<dyn AggregateFunction>>>,
68}
69
70impl AggregateFunctionRegistry {
71    pub fn new() -> Self {
72        let mut registry = Self {
73            functions: HashMap::new(),
74        };
75        registry.register_builtin_functions();
76        registry
77    }
78
79    /// Register an aggregate function
80    pub fn register(&mut self, function: Box<dyn AggregateFunction>) {
81        let name = function.name().to_uppercase();
82        self.functions.insert(name, Arc::new(function));
83    }
84
85    /// Get an aggregate function by name
86    pub fn get(&self, name: &str) -> Option<Arc<Box<dyn AggregateFunction>>> {
87        self.functions.get(&name.to_uppercase()).cloned()
88    }
89
90    /// Check if a function exists
91    pub fn contains(&self, name: &str) -> bool {
92        self.functions.contains_key(&name.to_uppercase())
93    }
94
95    /// List all registered functions
96    pub fn list_functions(&self) -> Vec<String> {
97        self.functions.keys().cloned().collect()
98    }
99
100    /// Register built-in aggregate functions
101    fn register_builtin_functions(&mut self) {
102        // Basic aggregates
103        self.register(Box::new(CountFunction));
104        self.register(Box::new(SumFunction));
105        self.register(Box::new(AvgFunction));
106        self.register(Box::new(MinFunction));
107        self.register(Box::new(MaxFunction));
108
109        // String aggregates
110        self.register(Box::new(StringAggFunction::new()));
111
112        // Statistical aggregates (to be implemented)
113        // self.register(Box::new(StdDevFunction));
114        // self.register(Box::new(VarianceFunction));
115    }
116}
117
118// ============= COUNT Implementation =============
119
120struct CountFunction;
121
122impl AggregateFunction for CountFunction {
123    fn name(&self) -> &str {
124        "COUNT"
125    }
126
127    fn description(&self) -> &str {
128        "Count the number of non-null values or rows"
129    }
130
131    fn create_state(&self) -> Box<dyn AggregateState> {
132        Box::new(CountState { count: 0 })
133    }
134}
135
136struct CountState {
137    count: i64,
138}
139
140impl AggregateState for CountState {
141    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
142        // COUNT(*) counts all rows, COUNT(column) counts non-nulls
143        if !matches!(value, DataValue::Null) {
144            self.count += 1;
145        }
146        Ok(())
147    }
148
149    fn finalize(self: Box<Self>) -> DataValue {
150        DataValue::Integer(self.count)
151    }
152
153    fn clone_box(&self) -> Box<dyn AggregateState> {
154        Box::new(CountState { count: self.count })
155    }
156
157    fn reset(&mut self) {
158        self.count = 0;
159    }
160}
161
162// ============= SUM Implementation =============
163
164struct SumFunction;
165
166impl AggregateFunction for SumFunction {
167    fn name(&self) -> &str {
168        "SUM"
169    }
170
171    fn description(&self) -> &str {
172        "Calculate the sum of values"
173    }
174
175    fn create_state(&self) -> Box<dyn AggregateState> {
176        Box::new(SumState {
177            int_sum: None,
178            float_sum: None,
179            has_values: false,
180        })
181    }
182}
183
184struct SumState {
185    int_sum: Option<i64>,
186    float_sum: Option<f64>,
187    has_values: bool,
188}
189
190impl AggregateState for SumState {
191    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
192        match value {
193            DataValue::Null => Ok(()), // Skip nulls
194            DataValue::Integer(n) => {
195                self.has_values = true;
196                if let Some(ref mut sum) = self.int_sum {
197                    *sum = sum.saturating_add(*n);
198                } else if let Some(ref mut fsum) = self.float_sum {
199                    *fsum += *n as f64;
200                } else {
201                    self.int_sum = Some(*n);
202                }
203                Ok(())
204            }
205            DataValue::Float(f) => {
206                self.has_values = true;
207                // Once we have a float, convert everything to float
208                if let Some(isum) = self.int_sum.take() {
209                    self.float_sum = Some(isum as f64 + f);
210                } else if let Some(ref mut fsum) = self.float_sum {
211                    *fsum += f;
212                } else {
213                    self.float_sum = Some(*f);
214                }
215                Ok(())
216            }
217            _ => Err(anyhow!("Cannot sum non-numeric value")),
218        }
219    }
220
221    fn finalize(self: Box<Self>) -> DataValue {
222        if !self.has_values {
223            return DataValue::Null;
224        }
225
226        if let Some(fsum) = self.float_sum {
227            DataValue::Float(fsum)
228        } else if let Some(isum) = self.int_sum {
229            DataValue::Integer(isum)
230        } else {
231            DataValue::Null
232        }
233    }
234
235    fn clone_box(&self) -> Box<dyn AggregateState> {
236        Box::new(SumState {
237            int_sum: self.int_sum,
238            float_sum: self.float_sum,
239            has_values: self.has_values,
240        })
241    }
242
243    fn reset(&mut self) {
244        self.int_sum = None;
245        self.float_sum = None;
246        self.has_values = false;
247    }
248}
249
250// ============= AVG Implementation =============
251
252struct AvgFunction;
253
254impl AggregateFunction for AvgFunction {
255    fn name(&self) -> &str {
256        "AVG"
257    }
258
259    fn description(&self) -> &str {
260        "Calculate the average of values"
261    }
262
263    fn create_state(&self) -> Box<dyn AggregateState> {
264        Box::new(AvgState {
265            sum: SumState {
266                int_sum: None,
267                float_sum: None,
268                has_values: false,
269            },
270            count: 0,
271        })
272    }
273}
274
275struct AvgState {
276    sum: SumState,
277    count: i64,
278}
279
280impl AggregateState for AvgState {
281    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
282        if !matches!(value, DataValue::Null) {
283            self.sum.accumulate(value)?;
284            self.count += 1;
285        }
286        Ok(())
287    }
288
289    fn finalize(self: Box<Self>) -> DataValue {
290        if self.count == 0 {
291            return DataValue::Null;
292        }
293
294        let sum = Box::new(self.sum).finalize();
295        match sum {
296            DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
297            DataValue::Float(f) => DataValue::Float(f / self.count as f64),
298            _ => DataValue::Null,
299        }
300    }
301
302    fn clone_box(&self) -> Box<dyn AggregateState> {
303        Box::new(AvgState {
304            sum: SumState {
305                int_sum: self.sum.int_sum,
306                float_sum: self.sum.float_sum,
307                has_values: self.sum.has_values,
308            },
309            count: self.count,
310        })
311    }
312
313    fn reset(&mut self) {
314        self.sum.reset();
315        self.count = 0;
316    }
317}
318
319// ============= MIN Implementation =============
320
321struct MinFunction;
322
323impl AggregateFunction for MinFunction {
324    fn name(&self) -> &str {
325        "MIN"
326    }
327
328    fn description(&self) -> &str {
329        "Find the minimum value"
330    }
331
332    fn create_state(&self) -> Box<dyn AggregateState> {
333        Box::new(MinMaxState {
334            is_min: true,
335            current: None,
336        })
337    }
338}
339
340// ============= MAX Implementation =============
341
342struct MaxFunction;
343
344impl AggregateFunction for MaxFunction {
345    fn name(&self) -> &str {
346        "MAX"
347    }
348
349    fn description(&self) -> &str {
350        "Find the maximum value"
351    }
352
353    fn create_state(&self) -> Box<dyn AggregateState> {
354        Box::new(MinMaxState {
355            is_min: false,
356            current: None,
357        })
358    }
359}
360
361struct MinMaxState {
362    is_min: bool,
363    current: Option<DataValue>,
364}
365
366impl AggregateState for MinMaxState {
367    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
368        if matches!(value, DataValue::Null) {
369            return Ok(());
370        }
371
372        match &self.current {
373            None => {
374                self.current = Some(value.clone());
375            }
376            Some(current) => {
377                let should_update = if self.is_min {
378                    value < current
379                } else {
380                    value > current
381                };
382
383                if should_update {
384                    self.current = Some(value.clone());
385                }
386            }
387        }
388
389        Ok(())
390    }
391
392    fn finalize(self: Box<Self>) -> DataValue {
393        self.current.unwrap_or(DataValue::Null)
394    }
395
396    fn clone_box(&self) -> Box<dyn AggregateState> {
397        Box::new(MinMaxState {
398            is_min: self.is_min,
399            current: self.current.clone(),
400        })
401    }
402
403    fn reset(&mut self) {
404        self.current = None;
405    }
406}
407
408// ============= STRING_AGG Implementation =============
409
410struct StringAggFunction {
411    separator: String,
412}
413
414impl StringAggFunction {
415    fn new() -> Self {
416        Self {
417            separator: ",".to_string(), // Default separator
418        }
419    }
420
421    fn with_separator(separator: String) -> Self {
422        Self { separator }
423    }
424}
425
426impl AggregateFunction for StringAggFunction {
427    fn name(&self) -> &str {
428        "STRING_AGG"
429    }
430
431    fn description(&self) -> &str {
432        "Concatenate strings with a separator"
433    }
434
435    fn create_state(&self) -> Box<dyn AggregateState> {
436        Box::new(StringAggState {
437            values: Vec::new(),
438            separator: self.separator.clone(),
439        })
440    }
441
442    fn set_parameters(&self, params: &[DataValue]) -> Result<Box<dyn AggregateFunction>> {
443        // STRING_AGG takes a separator as second parameter
444        if params.is_empty() {
445            return Ok(Box::new(StringAggFunction::new()));
446        }
447
448        let separator = match &params[0] {
449            DataValue::String(s) => s.clone(),
450            DataValue::InternedString(s) => s.to_string(),
451            _ => return Err(anyhow!("STRING_AGG separator must be a string")),
452        };
453
454        Ok(Box::new(StringAggFunction::with_separator(separator)))
455    }
456}
457
458struct StringAggState {
459    values: Vec<String>,
460    separator: String,
461}
462
463impl AggregateState for StringAggState {
464    fn accumulate(&mut self, value: &DataValue) -> Result<()> {
465        match value {
466            DataValue::Null => Ok(()), // Skip nulls
467            DataValue::String(s) => {
468                self.values.push(s.clone());
469                Ok(())
470            }
471            DataValue::InternedString(s) => {
472                self.values.push(s.to_string());
473                Ok(())
474            }
475            DataValue::Integer(n) => {
476                self.values.push(n.to_string());
477                Ok(())
478            }
479            DataValue::Float(f) => {
480                self.values.push(f.to_string());
481                Ok(())
482            }
483            DataValue::Boolean(b) => {
484                self.values.push(b.to_string());
485                Ok(())
486            }
487            DataValue::DateTime(dt) => {
488                self.values.push(dt.to_string());
489                Ok(())
490            }
491        }
492    }
493
494    fn finalize(self: Box<Self>) -> DataValue {
495        if self.values.is_empty() {
496            DataValue::Null
497        } else {
498            DataValue::String(self.values.join(&self.separator))
499        }
500    }
501
502    fn clone_box(&self) -> Box<dyn AggregateState> {
503        Box::new(StringAggState {
504            values: self.values.clone(),
505            separator: self.separator.clone(),
506        })
507    }
508
509    fn reset(&mut self) {
510        self.values.clear();
511    }
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// Feed `values` into a fresh state of `func` and return the finalized result.
    fn run(func: &dyn AggregateFunction, values: &[DataValue]) -> DataValue {
        let mut state = func.create_state();
        for value in values {
            state.accumulate(value).unwrap();
        }
        state.finalize()
    }

    #[test]
    fn test_registry_creation() {
        let registry = AggregateFunctionRegistry::new();
        assert!(registry.contains("COUNT"));
        assert!(registry.contains("SUM"));
        assert!(registry.contains("AVG"));
        assert!(registry.contains("MIN"));
        assert!(registry.contains("MAX"));
        assert!(registry.contains("STRING_AGG"));
        assert_eq!(registry.list_functions().len(), 6);
    }

    #[test]
    fn test_registry_lookup_is_case_insensitive() {
        let registry = AggregateFunctionRegistry::new();
        assert!(registry.contains("count"));
        assert_eq!(
            registry.get("Sum").map(|f| f.name().to_string()),
            Some("SUM".to_string())
        );
        assert!(registry.get("NO_SUCH_AGG").is_none());
    }

    #[test]
    fn test_count_aggregate() {
        let func = CountFunction;
        let mut state = func.create_state();

        state.accumulate(&DataValue::Integer(1)).unwrap();
        state.accumulate(&DataValue::Null).unwrap();
        state.accumulate(&DataValue::Integer(3)).unwrap();

        let result = state.finalize();
        assert_eq!(result, DataValue::Integer(2));
    }

    #[test]
    fn test_sum_integers_and_mixed() {
        assert_eq!(
            run(
                &SumFunction,
                &[DataValue::Integer(2), DataValue::Null, DataValue::Integer(3)]
            ),
            DataValue::Integer(5)
        );
        assert_eq!(
            run(&SumFunction, &[DataValue::Integer(1), DataValue::Float(2.5)]),
            DataValue::Float(3.5)
        );
        assert_eq!(run(&SumFunction, &[DataValue::Null]), DataValue::Null);
    }

    #[test]
    fn test_sum_rejects_non_numeric() {
        let mut state = SumFunction.create_state();
        assert!(state
            .accumulate(&DataValue::String("x".to_string()))
            .is_err());
    }

    #[test]
    fn test_avg_aggregate() {
        assert_eq!(
            run(
                &AvgFunction,
                &[DataValue::Integer(1), DataValue::Integer(2), DataValue::Null]
            ),
            DataValue::Float(1.5)
        );
        assert_eq!(run(&AvgFunction, &[]), DataValue::Null);
    }

    #[test]
    fn test_min_max_aggregate() {
        let values = [
            DataValue::Integer(4),
            DataValue::Null,
            DataValue::Integer(-1),
            DataValue::Integer(7),
        ];
        assert_eq!(run(&MinFunction, &values), DataValue::Integer(-1));
        assert_eq!(run(&MaxFunction, &values), DataValue::Integer(7));
        assert_eq!(run(&MinFunction, &[DataValue::Null]), DataValue::Null);
    }

    #[test]
    fn test_string_agg() {
        let func = StringAggFunction::with_separator(", ".to_string());
        let mut state = func.create_state();

        state
            .accumulate(&DataValue::String("apple".to_string()))
            .unwrap();
        state
            .accumulate(&DataValue::String("banana".to_string()))
            .unwrap();
        state
            .accumulate(&DataValue::String("cherry".to_string()))
            .unwrap();

        let result = state.finalize();
        assert_eq!(
            result,
            DataValue::String("apple, banana, cherry".to_string())
        );
    }

    #[test]
    fn test_string_agg_default_separator_and_nulls() {
        let func = StringAggFunction::new();
        let result = run(
            &func,
            &[
                DataValue::String("a".to_string()),
                DataValue::Null,
                DataValue::Integer(2),
            ],
        );
        assert_eq!(result, DataValue::String("a,2".to_string()));
        assert_eq!(run(&func, &[DataValue::Null]), DataValue::Null);
    }

    #[test]
    fn test_string_agg_set_parameters() {
        let func = StringAggFunction::new()
            .set_parameters(&[DataValue::String(" | ".to_string())])
            .unwrap();
        let result = run(
            &*func,
            &[
                DataValue::String("x".to_string()),
                DataValue::String("y".to_string()),
            ],
        );
        assert_eq!(result, DataValue::String("x | y".to_string()));
        assert!(StringAggFunction::new()
            .set_parameters(&[DataValue::Integer(1)])
            .is_err());
    }
}