// sql_cli/sql/aggregates/mod.rs

//! Aggregate functions for GROUP BY operations.
//!
//! This module provides SQL aggregate functions such as SUM, AVG, COUNT, MIN, and MAX
//! that work with the DataView partitioning system for efficient GROUP BY queries.

6use anyhow::{anyhow, Result};
7
8use crate::data::datatable::DataValue;
9
10pub mod analytics;
11pub mod functions;
12
13/// State maintained during aggregation
14#[derive(Debug, Clone)]
15pub enum AggregateState {
16    Count(i64),
17    Sum(SumState),
18    Avg(AvgState),
19    MinMax(MinMaxState),
20    CollectList(Vec<DataValue>),
21    Analytics(analytics::AnalyticsState),
22}
23
24/// State for SUM aggregation
25#[derive(Debug, Clone)]
26pub struct SumState {
27    pub int_sum: Option<i64>,
28    pub float_sum: Option<f64>,
29    pub has_values: bool,
30}
31
32impl SumState {
33    pub fn new() -> Self {
34        Self {
35            int_sum: None,
36            float_sum: None,
37            has_values: false,
38        }
39    }
40
41    pub fn add(&mut self, value: &DataValue) -> Result<()> {
42        match value {
43            DataValue::Null => Ok(()), // Skip nulls
44            DataValue::Integer(n) => {
45                self.has_values = true;
46                if let Some(ref mut sum) = self.int_sum {
47                    *sum = sum.saturating_add(*n);
48                } else if let Some(ref mut fsum) = self.float_sum {
49                    *fsum += *n as f64;
50                } else {
51                    self.int_sum = Some(*n);
52                }
53                Ok(())
54            }
55            DataValue::Float(f) => {
56                self.has_values = true;
57                // Once we have a float, convert everything to float
58                if let Some(isum) = self.int_sum.take() {
59                    self.float_sum = Some(isum as f64 + f);
60                } else if let Some(ref mut fsum) = self.float_sum {
61                    *fsum += f;
62                } else {
63                    self.float_sum = Some(*f);
64                }
65                Ok(())
66            }
67            _ => Err(anyhow!("Cannot sum non-numeric value")),
68        }
69    }
70
71    pub fn finalize(self) -> DataValue {
72        if !self.has_values {
73            return DataValue::Null;
74        }
75
76        if let Some(fsum) = self.float_sum {
77            DataValue::Float(fsum)
78        } else if let Some(isum) = self.int_sum {
79            DataValue::Integer(isum)
80        } else {
81            DataValue::Null
82        }
83    }
84}
85
86/// State for AVG aggregation
87#[derive(Debug, Clone)]
88pub struct AvgState {
89    pub sum: SumState,
90    pub count: i64,
91}
92
93impl AvgState {
94    pub fn new() -> Self {
95        Self {
96            sum: SumState::new(),
97            count: 0,
98        }
99    }
100
101    pub fn add(&mut self, value: &DataValue) -> Result<()> {
102        if !matches!(value, DataValue::Null) {
103            self.sum.add(value)?;
104            self.count += 1;
105        }
106        Ok(())
107    }
108
109    pub fn finalize(self) -> DataValue {
110        if self.count == 0 {
111            return DataValue::Null;
112        }
113
114        let sum = self.sum.finalize();
115        match sum {
116            DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
117            DataValue::Float(f) => DataValue::Float(f / self.count as f64),
118            _ => DataValue::Null,
119        }
120    }
121}
122
123/// State for MIN/MAX aggregation
124#[derive(Debug, Clone)]
125pub struct MinMaxState {
126    pub is_min: bool,
127    pub current: Option<DataValue>,
128}
129
130impl MinMaxState {
131    pub fn new(is_min: bool) -> Self {
132        Self {
133            is_min,
134            current: None,
135        }
136    }
137
138    pub fn add(&mut self, value: &DataValue) -> Result<()> {
139        if matches!(value, DataValue::Null) {
140            return Ok(());
141        }
142
143        if let Some(ref current) = self.current {
144            let should_update = if self.is_min {
145                value < current
146            } else {
147                value > current
148            };
149
150            if should_update {
151                self.current = Some(value.clone());
152            }
153        } else {
154            self.current = Some(value.clone());
155        }
156
157        Ok(())
158    }
159
160    pub fn finalize(self) -> DataValue {
161        self.current.unwrap_or(DataValue::Null)
162    }
163}
164
165/// Trait for all aggregate functions
166pub trait AggregateFunction: Send + Sync {
167    /// Name of the function (e.g., "SUM", "AVG")
168    fn name(&self) -> &str;
169
170    /// Initialize the aggregation state
171    fn init(&self) -> AggregateState;
172
173    /// Add a value to the aggregation
174    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()>;
175
176    /// Finalize and return the result
177    fn finalize(&self, state: AggregateState) -> DataValue;
178
179    /// Check if this function requires numeric input
180    fn requires_numeric(&self) -> bool {
181        false
182    }
183}
184
185/// Registry of aggregate functions
186pub struct AggregateRegistry {
187    functions: Vec<Box<dyn AggregateFunction>>,
188}
189
190impl AggregateRegistry {
191    pub fn new() -> Self {
192        use analytics::*;
193        use functions::*;
194
195        let functions: Vec<Box<dyn AggregateFunction>> = vec![
196            Box::new(CountFunction),
197            Box::new(CountStarFunction),
198            Box::new(SumFunction),
199            Box::new(AvgFunction),
200            Box::new(MinFunction),
201            Box::new(MaxFunction),
202            // Analytics functions
203            Box::new(DeltasFunction),
204            Box::new(SumsFunction),
205            Box::new(MavgFunction),
206            Box::new(PctChangeFunction),
207            Box::new(RankFunction),
208            Box::new(CumMaxFunction),
209            Box::new(CumMinFunction),
210        ];
211
212        Self { functions }
213    }
214
215    pub fn get(&self, name: &str) -> Option<&dyn AggregateFunction> {
216        let name_upper = name.to_uppercase();
217        self.functions
218            .iter()
219            .find(|f| f.name() == name_upper)
220            .map(|f| f.as_ref())
221    }
222
223    pub fn is_aggregate(&self, name: &str) -> bool {
224        self.get(name).is_some() || name.to_uppercase() == "COUNT" // COUNT(*) special case
225    }
226}
227
228impl Default for AggregateRegistry {
229    fn default() -> Self {
230        Self::new()
231    }
232}
233
234/// Check if an expression contains aggregate functions
235pub fn contains_aggregate(expr: &crate::recursive_parser::SqlExpression) -> bool {
236    use crate::recursive_parser::SqlExpression;
237
238    match expr {
239        SqlExpression::FunctionCall { name, args } => {
240            let registry = AggregateRegistry::new();
241            if registry.is_aggregate(name) {
242                return true;
243            }
244            // Check nested expressions
245            args.iter().any(contains_aggregate)
246        }
247        SqlExpression::BinaryOp { left, right, .. } => {
248            contains_aggregate(left) || contains_aggregate(right)
249        }
250        SqlExpression::Not { expr } => contains_aggregate(expr),
251        SqlExpression::CaseExpression {
252            when_branches,
253            else_branch,
254        } => {
255            when_branches.iter().any(|branch| {
256                contains_aggregate(&branch.condition) || contains_aggregate(&branch.result)
257            }) || else_branch
258                .as_ref()
259                .map_or(false, |e| contains_aggregate(e))
260        }
261        _ => false,
262    }
263}