sql_cli/sql/aggregates/
mod.rs

//! Aggregate functions for GROUP BY operations.
//!
//! This module provides SQL aggregate functions such as SUM, AVG, COUNT, MIN, and MAX
//! that work with the `DataView` partitioning system for efficient GROUP BY queries.
5
6use anyhow::{anyhow, Result};
7
8use crate::data::datatable::DataValue;
9
10pub mod analytics;
11pub mod functions;
12
13/// State maintained during aggregation
14#[derive(Debug, Clone)]
15pub enum AggregateState {
16    Count(i64),
17    Sum(SumState),
18    Avg(AvgState),
19    MinMax(MinMaxState),
20    Variance(VarianceState),
21    CollectList(Vec<DataValue>),
22    Analytics(analytics::AnalyticsState),
23}
24
25/// State for SUM aggregation
26#[derive(Debug, Clone)]
27pub struct SumState {
28    pub int_sum: Option<i64>,
29    pub float_sum: Option<f64>,
30    pub has_values: bool,
31}
32
33impl Default for SumState {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl SumState {
40    #[must_use]
41    pub fn new() -> Self {
42        Self {
43            int_sum: None,
44            float_sum: None,
45            has_values: false,
46        }
47    }
48
49    pub fn add(&mut self, value: &DataValue) -> Result<()> {
50        match value {
51            DataValue::Null => Ok(()), // Skip nulls
52            DataValue::Integer(n) => {
53                self.has_values = true;
54                if let Some(ref mut sum) = self.int_sum {
55                    *sum = sum.saturating_add(*n);
56                } else if let Some(ref mut fsum) = self.float_sum {
57                    *fsum += *n as f64;
58                } else {
59                    self.int_sum = Some(*n);
60                }
61                Ok(())
62            }
63            DataValue::Float(f) => {
64                self.has_values = true;
65                // Once we have a float, convert everything to float
66                if let Some(isum) = self.int_sum.take() {
67                    self.float_sum = Some(isum as f64 + f);
68                } else if let Some(ref mut fsum) = self.float_sum {
69                    *fsum += f;
70                } else {
71                    self.float_sum = Some(*f);
72                }
73                Ok(())
74            }
75            _ => Err(anyhow!("Cannot sum non-numeric value")),
76        }
77    }
78
79    #[must_use]
80    pub fn finalize(self) -> DataValue {
81        if !self.has_values {
82            return DataValue::Null;
83        }
84
85        if let Some(fsum) = self.float_sum {
86            DataValue::Float(fsum)
87        } else if let Some(isum) = self.int_sum {
88            DataValue::Integer(isum)
89        } else {
90            DataValue::Null
91        }
92    }
93}
94
95/// State for AVG aggregation
96#[derive(Debug, Clone)]
97pub struct AvgState {
98    pub sum: SumState,
99    pub count: i64,
100}
101
102impl Default for AvgState {
103    fn default() -> Self {
104        Self::new()
105    }
106}
107
108impl AvgState {
109    #[must_use]
110    pub fn new() -> Self {
111        Self {
112            sum: SumState::new(),
113            count: 0,
114        }
115    }
116
117    pub fn add(&mut self, value: &DataValue) -> Result<()> {
118        if !matches!(value, DataValue::Null) {
119            self.sum.add(value)?;
120            self.count += 1;
121        }
122        Ok(())
123    }
124
125    #[must_use]
126    pub fn finalize(self) -> DataValue {
127        if self.count == 0 {
128            return DataValue::Null;
129        }
130
131        let sum = self.sum.finalize();
132        match sum {
133            DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
134            DataValue::Float(f) => DataValue::Float(f / self.count as f64),
135            _ => DataValue::Null,
136        }
137    }
138}
139
140/// State for MIN/MAX aggregation
141#[derive(Debug, Clone)]
142pub struct MinMaxState {
143    pub is_min: bool,
144    pub current: Option<DataValue>,
145}
146
147impl MinMaxState {
148    #[must_use]
149    pub fn new(is_min: bool) -> Self {
150        Self {
151            is_min,
152            current: None,
153        }
154    }
155
156    pub fn add(&mut self, value: &DataValue) -> Result<()> {
157        if matches!(value, DataValue::Null) {
158            return Ok(());
159        }
160
161        if let Some(ref current) = self.current {
162            let should_update = if self.is_min {
163                value < current
164            } else {
165                value > current
166            };
167
168            if should_update {
169                self.current = Some(value.clone());
170            }
171        } else {
172            self.current = Some(value.clone());
173        }
174
175        Ok(())
176    }
177
178    #[must_use]
179    pub fn finalize(self) -> DataValue {
180        self.current.unwrap_or(DataValue::Null)
181    }
182}
183
184/// State for VARIANCE/STDDEV aggregation
185#[derive(Debug, Clone)]
186pub struct VarianceState {
187    pub sum: f64,
188    pub sum_of_squares: f64,
189    pub count: i64,
190}
191
192impl Default for VarianceState {
193    fn default() -> Self {
194        Self::new()
195    }
196}
197
198impl VarianceState {
199    #[must_use]
200    pub fn new() -> Self {
201        Self {
202            sum: 0.0,
203            sum_of_squares: 0.0,
204            count: 0,
205        }
206    }
207
208    pub fn add(&mut self, value: &DataValue) -> Result<()> {
209        match value {
210            DataValue::Null => Ok(()), // Skip nulls
211            DataValue::Integer(n) => {
212                let f = *n as f64;
213                self.sum += f;
214                self.sum_of_squares += f * f;
215                self.count += 1;
216                Ok(())
217            }
218            DataValue::Float(f) => {
219                self.sum += f;
220                self.sum_of_squares += f * f;
221                self.count += 1;
222                Ok(())
223            }
224            _ => Err(anyhow!("Cannot compute variance of non-numeric value")),
225        }
226    }
227
228    #[must_use]
229    pub fn variance(&self) -> f64 {
230        if self.count <= 1 {
231            return 0.0;
232        }
233        let mean = self.sum / self.count as f64;
234        (self.sum_of_squares / self.count as f64) - (mean * mean)
235    }
236
237    #[must_use]
238    pub fn stddev(&self) -> f64 {
239        self.variance().sqrt()
240    }
241
242    #[must_use]
243    pub fn finalize_variance(self) -> DataValue {
244        if self.count == 0 {
245            DataValue::Null
246        } else {
247            DataValue::Float(self.variance())
248        }
249    }
250
251    #[must_use]
252    pub fn finalize_stddev(self) -> DataValue {
253        if self.count == 0 {
254            DataValue::Null
255        } else {
256            DataValue::Float(self.stddev())
257        }
258    }
259}
260
261/// Trait for all aggregate functions
262pub trait AggregateFunction: Send + Sync {
263    /// Name of the function (e.g., "SUM", "AVG")
264    fn name(&self) -> &str;
265
266    /// Initialize the aggregation state
267    fn init(&self) -> AggregateState;
268
269    /// Add a value to the aggregation
270    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()>;
271
272    /// Finalize and return the result
273    fn finalize(&self, state: AggregateState) -> DataValue;
274
275    /// Check if this function requires numeric input
276    fn requires_numeric(&self) -> bool {
277        false
278    }
279}
280
281/// Registry of aggregate functions
282pub struct AggregateRegistry {
283    functions: Vec<Box<dyn AggregateFunction>>,
284}
285
286impl AggregateRegistry {
287    #[must_use]
288    pub fn new() -> Self {
289        use analytics::{
290            CumMaxFunction, CumMinFunction, DeltasFunction, MavgFunction, PctChangeFunction,
291            RankFunction, SumsFunction,
292        };
293        use functions::{
294            AvgFunction, CountFunction, CountStarFunction, MaxFunction, MinFunction,
295            StdDevFunction, SumFunction, VarianceFunction,
296        };
297
298        let functions: Vec<Box<dyn AggregateFunction>> = vec![
299            Box::new(CountFunction),
300            Box::new(CountStarFunction),
301            Box::new(SumFunction),
302            Box::new(AvgFunction),
303            Box::new(MinFunction),
304            Box::new(MaxFunction),
305            Box::new(StdDevFunction),
306            Box::new(VarianceFunction),
307            // Analytics functions
308            Box::new(DeltasFunction),
309            Box::new(SumsFunction),
310            Box::new(MavgFunction),
311            Box::new(PctChangeFunction),
312            Box::new(RankFunction),
313            Box::new(CumMaxFunction),
314            Box::new(CumMinFunction),
315        ];
316
317        Self { functions }
318    }
319
320    #[must_use]
321    pub fn get(&self, name: &str) -> Option<&dyn AggregateFunction> {
322        let name_upper = name.to_uppercase();
323        self.functions
324            .iter()
325            .find(|f| f.name() == name_upper)
326            .map(std::convert::AsRef::as_ref)
327    }
328
329    #[must_use]
330    pub fn is_aggregate(&self, name: &str) -> bool {
331        self.get(name).is_some() || name.to_uppercase() == "COUNT" // COUNT(*) special case
332    }
333}
334
335impl Default for AggregateRegistry {
336    fn default() -> Self {
337        Self::new()
338    }
339}
340
341/// Check if an expression contains aggregate functions
342pub fn contains_aggregate(expr: &crate::recursive_parser::SqlExpression) -> bool {
343    use crate::recursive_parser::SqlExpression;
344
345    match expr {
346        SqlExpression::FunctionCall { name, args, .. } => {
347            let registry = AggregateRegistry::new();
348            if registry.is_aggregate(name) {
349                return true;
350            }
351            // Check nested expressions
352            args.iter().any(contains_aggregate)
353        }
354        SqlExpression::BinaryOp { left, right, .. } => {
355            contains_aggregate(left) || contains_aggregate(right)
356        }
357        SqlExpression::Not { expr } => contains_aggregate(expr),
358        SqlExpression::CaseExpression {
359            when_branches,
360            else_branch,
361        } => {
362            when_branches.iter().any(|branch| {
363                contains_aggregate(&branch.condition) || contains_aggregate(&branch.result)
364            }) || else_branch.as_ref().is_some_and(|e| contains_aggregate(e))
365        }
366        _ => false,
367    }
368}