// sql_cli/sql/aggregates/mod.rs

//! Aggregate functions for GROUP BY operations
//!
//! This module provides SQL aggregate functions like SUM, AVG, COUNT, MIN, MAX
//! that work with the DataView partitioning system for efficient GROUP BY queries.
5
6use anyhow::{anyhow, Result};
7
8use crate::data::datatable::DataValue;
9
10pub mod functions;
11
12/// State maintained during aggregation
13#[derive(Debug, Clone)]
14pub enum AggregateState {
15    Count(i64),
16    Sum(SumState),
17    Avg(AvgState),
18    MinMax(MinMaxState),
19    CollectList(Vec<DataValue>),
20}
21
/// Accumulator backing the SUM aggregate.
///
/// The total is tracked as an integer for as long as only integers have been
/// added, and moves to the float accumulator once a float value is seen.
#[derive(Debug, Clone)]
pub struct SumState {
    /// Integer running total; `None` before the first integer is added or
    /// after the total has been moved into `float_sum`.
    pub int_sum: Option<i64>,
    /// Floating-point running total; `None` until a float value is added.
    pub float_sum: Option<f64>,
    /// `true` once at least one non-null numeric value has been added.
    pub has_values: bool,
}
29
30impl SumState {
31    pub fn new() -> Self {
32        Self {
33            int_sum: None,
34            float_sum: None,
35            has_values: false,
36        }
37    }
38
39    pub fn add(&mut self, value: &DataValue) -> Result<()> {
40        match value {
41            DataValue::Null => Ok(()), // Skip nulls
42            DataValue::Integer(n) => {
43                self.has_values = true;
44                if let Some(ref mut sum) = self.int_sum {
45                    *sum = sum.saturating_add(*n);
46                } else if let Some(ref mut fsum) = self.float_sum {
47                    *fsum += *n as f64;
48                } else {
49                    self.int_sum = Some(*n);
50                }
51                Ok(())
52            }
53            DataValue::Float(f) => {
54                self.has_values = true;
55                // Once we have a float, convert everything to float
56                if let Some(isum) = self.int_sum.take() {
57                    self.float_sum = Some(isum as f64 + f);
58                } else if let Some(ref mut fsum) = self.float_sum {
59                    *fsum += f;
60                } else {
61                    self.float_sum = Some(*f);
62                }
63                Ok(())
64            }
65            _ => Err(anyhow!("Cannot sum non-numeric value")),
66        }
67    }
68
69    pub fn finalize(self) -> DataValue {
70        if !self.has_values {
71            return DataValue::Null;
72        }
73
74        if let Some(fsum) = self.float_sum {
75            DataValue::Float(fsum)
76        } else if let Some(isum) = self.int_sum {
77            DataValue::Integer(isum)
78        } else {
79            DataValue::Null
80        }
81    }
82}
83
84/// State for AVG aggregation
85#[derive(Debug, Clone)]
86pub struct AvgState {
87    pub sum: SumState,
88    pub count: i64,
89}
90
91impl AvgState {
92    pub fn new() -> Self {
93        Self {
94            sum: SumState::new(),
95            count: 0,
96        }
97    }
98
99    pub fn add(&mut self, value: &DataValue) -> Result<()> {
100        if !matches!(value, DataValue::Null) {
101            self.sum.add(value)?;
102            self.count += 1;
103        }
104        Ok(())
105    }
106
107    pub fn finalize(self) -> DataValue {
108        if self.count == 0 {
109            return DataValue::Null;
110        }
111
112        let sum = self.sum.finalize();
113        match sum {
114            DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
115            DataValue::Float(f) => DataValue::Float(f / self.count as f64),
116            _ => DataValue::Null,
117        }
118    }
119}
120
121/// State for MIN/MAX aggregation
122#[derive(Debug, Clone)]
123pub struct MinMaxState {
124    pub is_min: bool,
125    pub current: Option<DataValue>,
126}
127
128impl MinMaxState {
129    pub fn new(is_min: bool) -> Self {
130        Self {
131            is_min,
132            current: None,
133        }
134    }
135
136    pub fn add(&mut self, value: &DataValue) -> Result<()> {
137        if matches!(value, DataValue::Null) {
138            return Ok(());
139        }
140
141        if let Some(ref current) = self.current {
142            let should_update = if self.is_min {
143                value < current
144            } else {
145                value > current
146            };
147
148            if should_update {
149                self.current = Some(value.clone());
150            }
151        } else {
152            self.current = Some(value.clone());
153        }
154
155        Ok(())
156    }
157
158    pub fn finalize(self) -> DataValue {
159        self.current.unwrap_or(DataValue::Null)
160    }
161}
162
163/// Trait for all aggregate functions
164pub trait AggregateFunction: Send + Sync {
165    /// Name of the function (e.g., "SUM", "AVG")
166    fn name(&self) -> &str;
167
168    /// Initialize the aggregation state
169    fn init(&self) -> AggregateState;
170
171    /// Add a value to the aggregation
172    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()>;
173
174    /// Finalize and return the result
175    fn finalize(&self, state: AggregateState) -> DataValue;
176
177    /// Check if this function requires numeric input
178    fn requires_numeric(&self) -> bool {
179        false
180    }
181}
182
183/// Registry of aggregate functions
184pub struct AggregateRegistry {
185    functions: Vec<Box<dyn AggregateFunction>>,
186}
187
188impl AggregateRegistry {
189    pub fn new() -> Self {
190        use functions::*;
191
192        let functions: Vec<Box<dyn AggregateFunction>> = vec![
193            Box::new(CountFunction),
194            Box::new(CountStarFunction),
195            Box::new(SumFunction),
196            Box::new(AvgFunction),
197            Box::new(MinFunction),
198            Box::new(MaxFunction),
199        ];
200
201        Self { functions }
202    }
203
204    pub fn get(&self, name: &str) -> Option<&dyn AggregateFunction> {
205        let name_upper = name.to_uppercase();
206        self.functions
207            .iter()
208            .find(|f| f.name() == name_upper)
209            .map(|f| f.as_ref())
210    }
211
212    pub fn is_aggregate(&self, name: &str) -> bool {
213        self.get(name).is_some() || name.to_uppercase() == "COUNT" // COUNT(*) special case
214    }
215}
216
217impl Default for AggregateRegistry {
218    fn default() -> Self {
219        Self::new()
220    }
221}
222
223/// Check if an expression contains aggregate functions
224pub fn contains_aggregate(expr: &crate::recursive_parser::SqlExpression) -> bool {
225    use crate::recursive_parser::SqlExpression;
226
227    match expr {
228        SqlExpression::FunctionCall { name, args } => {
229            let registry = AggregateRegistry::new();
230            if registry.is_aggregate(name) {
231                return true;
232            }
233            // Check nested expressions
234            args.iter().any(contains_aggregate)
235        }
236        SqlExpression::BinaryOp { left, right, .. } => {
237            contains_aggregate(left) || contains_aggregate(right)
238        }
239        SqlExpression::Not { expr } => contains_aggregate(expr),
240        SqlExpression::CaseExpression {
241            when_branches,
242            else_branch,
243        } => {
244            when_branches.iter().any(|branch| {
245                contains_aggregate(&branch.condition) || contains_aggregate(&branch.result)
246            }) || else_branch
247                .as_ref()
248                .map_or(false, |e| contains_aggregate(e))
249        }
250        _ => false,
251    }
252}