sql_cli/sql/aggregates/
mod.rs

1//! Aggregate functions for GROUP BY operations
2//!
3//! This module provides SQL aggregate functions like SUM, AVG, COUNT, MIN, MAX
4//! that work with the `DataView` partitioning system for efficient GROUP BY queries.
5
6use anyhow::{anyhow, Result};
7
8use crate::data::datatable::DataValue;
9
10pub mod analytics;
11pub mod functions;
12
13/// State maintained during aggregation: one variant per kind of accumulator an aggregate can carry.
14#[derive(Debug, Clone)]
15pub enum AggregateState {
16    Count(i64), // plain running counter
17    Sum(SumState), // SUM accumulator (see `SumState`)
18    Avg(AvgState), // AVG accumulator: a sum plus a count of non-null values
19    MinMax(MinMaxState), // shared by MIN and MAX; `MinMaxState::is_min` selects the direction
20    Variance(VarianceState), // shared by VARIANCE and STDDEV
21    CollectList(Vec<DataValue>), // raw list of collected values
22    Percentile(PercentileState), // buffered values plus the requested percentile
23    Mode(ModeState), // per-value occurrence counts
24    Analytics(analytics::AnalyticsState), // state type defined in the `analytics` submodule
25    StringAgg(StringAggState), // STRING_AGG accumulator: rendered values plus separator
26}
27
28/// State for SUM aggregation
29#[derive(Debug, Clone)]
30pub struct SumState {
31    pub int_sum: Option<i64>,
32    pub float_sum: Option<f64>,
33    pub has_values: bool,
34}
35
36impl Default for SumState {
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42impl SumState {
43    #[must_use]
44    pub fn new() -> Self {
45        Self {
46            int_sum: None,
47            float_sum: None,
48            has_values: false,
49        }
50    }
51
52    pub fn add(&mut self, value: &DataValue) -> Result<()> {
53        match value {
54            DataValue::Null => Ok(()), // Skip nulls
55            DataValue::Integer(n) => {
56                self.has_values = true;
57                if let Some(ref mut sum) = self.int_sum {
58                    *sum = sum.saturating_add(*n);
59                } else if let Some(ref mut fsum) = self.float_sum {
60                    *fsum += *n as f64;
61                } else {
62                    self.int_sum = Some(*n);
63                }
64                Ok(())
65            }
66            DataValue::Float(f) => {
67                self.has_values = true;
68                // Once we have a float, convert everything to float
69                if let Some(isum) = self.int_sum.take() {
70                    self.float_sum = Some(isum as f64 + f);
71                } else if let Some(ref mut fsum) = self.float_sum {
72                    *fsum += f;
73                } else {
74                    self.float_sum = Some(*f);
75                }
76                Ok(())
77            }
78            _ => Err(anyhow!("Cannot sum non-numeric value")),
79        }
80    }
81
82    #[must_use]
83    pub fn finalize(self) -> DataValue {
84        if !self.has_values {
85            return DataValue::Null;
86        }
87
88        if let Some(fsum) = self.float_sum {
89            DataValue::Float(fsum)
90        } else if let Some(isum) = self.int_sum {
91            DataValue::Integer(isum)
92        } else {
93            DataValue::Null
94        }
95    }
96}
97
98/// State for AVG aggregation
99#[derive(Debug, Clone)]
100pub struct AvgState {
101    pub sum: SumState,
102    pub count: i64,
103}
104
105impl Default for AvgState {
106    fn default() -> Self {
107        Self::new()
108    }
109}
110
111impl AvgState {
112    #[must_use]
113    pub fn new() -> Self {
114        Self {
115            sum: SumState::new(),
116            count: 0,
117        }
118    }
119
120    pub fn add(&mut self, value: &DataValue) -> Result<()> {
121        if !matches!(value, DataValue::Null) {
122            self.sum.add(value)?;
123            self.count += 1;
124        }
125        Ok(())
126    }
127
128    #[must_use]
129    pub fn finalize(self) -> DataValue {
130        if self.count == 0 {
131            return DataValue::Null;
132        }
133
134        let sum = self.sum.finalize();
135        match sum {
136            DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
137            DataValue::Float(f) => DataValue::Float(f / self.count as f64),
138            _ => DataValue::Null,
139        }
140    }
141}
142
143/// State for MIN/MAX aggregation
144#[derive(Debug, Clone)]
145pub struct MinMaxState {
146    pub is_min: bool,
147    pub current: Option<DataValue>,
148}
149
150impl MinMaxState {
151    #[must_use]
152    pub fn new(is_min: bool) -> Self {
153        Self {
154            is_min,
155            current: None,
156        }
157    }
158
159    pub fn add(&mut self, value: &DataValue) -> Result<()> {
160        if matches!(value, DataValue::Null) {
161            return Ok(());
162        }
163
164        if let Some(ref current) = self.current {
165            let should_update = if self.is_min {
166                value < current
167            } else {
168                value > current
169            };
170
171            if should_update {
172                self.current = Some(value.clone());
173            }
174        } else {
175            self.current = Some(value.clone());
176        }
177
178        Ok(())
179    }
180
181    #[must_use]
182    pub fn finalize(self) -> DataValue {
183        self.current.unwrap_or(DataValue::Null)
184    }
185}
186
187/// State for VARIANCE/STDDEV aggregation
188#[derive(Debug, Clone)]
189pub struct VarianceState {
190    pub sum: f64,
191    pub sum_of_squares: f64,
192    pub count: i64,
193}
194
195impl Default for VarianceState {
196    fn default() -> Self {
197        Self::new()
198    }
199}
200
201impl VarianceState {
202    #[must_use]
203    pub fn new() -> Self {
204        Self {
205            sum: 0.0,
206            sum_of_squares: 0.0,
207            count: 0,
208        }
209    }
210
211    pub fn add(&mut self, value: &DataValue) -> Result<()> {
212        match value {
213            DataValue::Null => Ok(()), // Skip nulls
214            DataValue::Integer(n) => {
215                let f = *n as f64;
216                self.sum += f;
217                self.sum_of_squares += f * f;
218                self.count += 1;
219                Ok(())
220            }
221            DataValue::Float(f) => {
222                self.sum += f;
223                self.sum_of_squares += f * f;
224                self.count += 1;
225                Ok(())
226            }
227            _ => Err(anyhow!("Cannot compute variance of non-numeric value")),
228        }
229    }
230
231    #[must_use]
232    pub fn variance(&self) -> f64 {
233        if self.count <= 1 {
234            return 0.0;
235        }
236        let mean = self.sum / self.count as f64;
237        (self.sum_of_squares / self.count as f64) - (mean * mean)
238    }
239
240    #[must_use]
241    pub fn stddev(&self) -> f64 {
242        self.variance().sqrt()
243    }
244
245    #[must_use]
246    pub fn finalize_variance(self) -> DataValue {
247        if self.count == 0 {
248            DataValue::Null
249        } else {
250            DataValue::Float(self.variance())
251        }
252    }
253
254    #[must_use]
255    pub fn finalize_stddev(self) -> DataValue {
256        if self.count == 0 {
257            DataValue::Null
258        } else {
259            DataValue::Float(self.stddev())
260        }
261    }
262}
263
264/// State for PERCENTILE aggregation
265#[derive(Debug, Clone)]
266pub struct PercentileState {
267    pub values: Vec<DataValue>,
268    pub percentile: f64,
269}
270
271impl Default for PercentileState {
272    fn default() -> Self {
273        Self::new(50.0) // Default to median
274    }
275}
276
277impl PercentileState {
278    #[must_use]
279    pub fn new(percentile: f64) -> Self {
280        Self {
281            values: Vec::new(),
282            percentile: percentile.clamp(0.0, 100.0),
283        }
284    }
285
286    pub fn add(&mut self, value: &DataValue) -> Result<()> {
287        if !matches!(value, DataValue::Null) {
288            self.values.push(value.clone());
289        }
290        Ok(())
291    }
292
293    #[must_use]
294    pub fn finalize(mut self) -> DataValue {
295        if self.values.is_empty() {
296            return DataValue::Null;
297        }
298
299        // Sort values for percentile calculation
300        self.values.sort_by(|a, b| {
301            use std::cmp::Ordering;
302            match (a, b) {
303                (DataValue::Integer(a), DataValue::Integer(b)) => a.cmp(b),
304                (DataValue::Float(a), DataValue::Float(b)) => {
305                    a.partial_cmp(b).unwrap_or(Ordering::Equal)
306                }
307                (DataValue::Integer(a), DataValue::Float(b)) => {
308                    (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal)
309                }
310                (DataValue::Float(a), DataValue::Integer(b)) => {
311                    a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal)
312                }
313                _ => Ordering::Equal,
314            }
315        });
316
317        let n = self.values.len();
318        if self.percentile == 0.0 {
319            return self.values[0].clone();
320        }
321        if self.percentile == 100.0 {
322            return self.values[n - 1].clone();
323        }
324
325        // Calculate percentile using linear interpolation
326        let pos = (self.percentile / 100.0) * ((n - 1) as f64);
327        let lower_idx = pos.floor() as usize;
328        let upper_idx = pos.ceil() as usize;
329
330        if lower_idx == upper_idx {
331            // Exact position
332            self.values[lower_idx].clone()
333        } else {
334            // Interpolate between two values
335            let fraction = pos - lower_idx as f64;
336            let lower_val = &self.values[lower_idx];
337            let upper_val = &self.values[upper_idx];
338
339            match (lower_val, upper_val) {
340                (DataValue::Integer(a), DataValue::Integer(b)) => {
341                    let result = *a as f64 + fraction * (*b - *a) as f64;
342                    if result.fract() == 0.0 {
343                        DataValue::Integer(result as i64)
344                    } else {
345                        DataValue::Float(result)
346                    }
347                }
348                (DataValue::Float(a), DataValue::Float(b)) => {
349                    DataValue::Float(a + fraction * (b - a))
350                }
351                (DataValue::Integer(a), DataValue::Float(b)) => {
352                    DataValue::Float(*a as f64 + fraction * (b - *a as f64))
353                }
354                (DataValue::Float(a), DataValue::Integer(b)) => {
355                    DataValue::Float(a + fraction * (*b as f64 - a))
356                }
357                // For non-numeric, return the lower value
358                _ => lower_val.clone(),
359            }
360        }
361    }
362}
363
364/// State for MODE aggregation (most frequent value)
365#[derive(Debug, Clone)]
366pub struct ModeState {
367    pub counts: std::collections::HashMap<String, (DataValue, i64)>,
368}
369
370impl Default for ModeState {
371    fn default() -> Self {
372        Self::new()
373    }
374}
375
376impl ModeState {
377    #[must_use]
378    pub fn new() -> Self {
379        Self {
380            counts: std::collections::HashMap::new(),
381        }
382    }
383
384    pub fn add(&mut self, value: &DataValue) -> Result<()> {
385        if matches!(value, DataValue::Null) {
386            return Ok(());
387        }
388
389        // Convert value to string for hashing, but keep original value for result
390        let key = match value {
391            DataValue::String(s) => s.clone(),
392            DataValue::InternedString(s) => s.to_string(),
393            DataValue::Integer(i) => i.to_string(),
394            DataValue::Float(f) => f.to_string(),
395            DataValue::Boolean(b) => b.to_string(),
396            DataValue::DateTime(dt) => dt.to_string(),
397            DataValue::Null => return Ok(()),
398        };
399
400        // Update count and store the original value
401        let entry = self.counts.entry(key).or_insert((value.clone(), 0));
402        entry.1 += 1;
403
404        Ok(())
405    }
406
407    #[must_use]
408    pub fn finalize(self) -> DataValue {
409        if self.counts.is_empty() {
410            return DataValue::Null;
411        }
412
413        // Find the value with the highest count
414        let max_entry = self.counts.iter().max_by_key(|(_, (_, count))| count);
415
416        match max_entry {
417            Some((_, (value, _count))) => value.clone(),
418            None => DataValue::Null,
419        }
420    }
421}
422
423/// Trait for all aggregate functions; `AggregateRegistry` stores them as boxed trait objects.
424pub trait AggregateFunction: Send + Sync {
425    /// Name of the function (e.g., "SUM", "AVG"); `AggregateRegistry::get` compares it with the upper-cased lookup name.
426    fn name(&self) -> &str;
427
428    /// Build the initial aggregation state for this function.
429    fn init(&self) -> AggregateState;
430
431    /// Add one value to the aggregation held in `state`; the `Result` lets an implementation reject a value.
432    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()>;
433
434    /// Consume the accumulated `state` and return the aggregate's result.
435    fn finalize(&self, state: AggregateState) -> DataValue;
436
437    /// Check if this function requires numeric input; the default implementation returns `false`.
438    fn requires_numeric(&self) -> bool {
439        false
440    }
441}
442
443/// State for STRING_AGG aggregation
444#[derive(Debug, Clone)]
445pub struct StringAggState {
446    pub values: Vec<String>,
447    pub separator: String,
448}
449
450impl Default for StringAggState {
451    fn default() -> Self {
452        Self::new(",")
453    }
454}
455
456impl StringAggState {
457    #[must_use]
458    pub fn new(separator: &str) -> Self {
459        Self {
460            values: Vec::new(),
461            separator: separator.to_string(),
462        }
463    }
464
465    pub fn add(&mut self, value: &DataValue) -> Result<()> {
466        match value {
467            DataValue::Null => Ok(()), // Skip nulls
468            DataValue::String(s) => {
469                self.values.push(s.clone());
470                Ok(())
471            }
472            DataValue::InternedString(s) => {
473                self.values.push(s.to_string());
474                Ok(())
475            }
476            DataValue::Integer(n) => {
477                self.values.push(n.to_string());
478                Ok(())
479            }
480            DataValue::Float(f) => {
481                self.values.push(f.to_string());
482                Ok(())
483            }
484            DataValue::Boolean(b) => {
485                self.values.push(b.to_string());
486                Ok(())
487            }
488            DataValue::DateTime(dt) => {
489                self.values.push(dt.to_string());
490                Ok(())
491            }
492        }
493    }
494
495    #[must_use]
496    pub fn finalize(self) -> DataValue {
497        if self.values.is_empty() {
498            DataValue::Null
499        } else {
500            DataValue::String(self.values.join(&self.separator))
501        }
502    }
503}
504
505/// Registry of aggregate functions
506pub struct AggregateRegistry {
507    functions: Vec<Box<dyn AggregateFunction>>,
508}
509
510impl AggregateRegistry {
511    #[must_use]
512    pub fn new() -> Self {
513        use analytics::{
514            CumMaxFunction, CumMinFunction, DeltasFunction, MavgFunction, PctChangeFunction,
515            RankFunction, SumsFunction,
516        };
517        use functions::{
518            AvgFunction, CountFunction, CountStarFunction, MaxFunction, MedianFunction,
519            MinFunction, ModeFunction, PercentileFunction, StdDevFunction, StringAggFunction,
520            SumFunction, VarianceFunction,
521        };
522
523        let functions: Vec<Box<dyn AggregateFunction>> = vec![
524            Box::new(CountFunction),
525            Box::new(CountStarFunction),
526            Box::new(SumFunction),
527            Box::new(AvgFunction),
528            Box::new(MinFunction),
529            Box::new(MaxFunction),
530            Box::new(StdDevFunction),
531            Box::new(VarianceFunction),
532            Box::new(MedianFunction),
533            Box::new(ModeFunction),
534            Box::new(PercentileFunction),
535            Box::new(StringAggFunction),
536            // Analytics functions
537            Box::new(DeltasFunction),
538            Box::new(SumsFunction),
539            Box::new(MavgFunction),
540            Box::new(PctChangeFunction),
541            Box::new(RankFunction),
542            Box::new(CumMaxFunction),
543            Box::new(CumMinFunction),
544        ];
545
546        Self { functions }
547    }
548
549    #[must_use]
550    pub fn get(&self, name: &str) -> Option<&dyn AggregateFunction> {
551        let name_upper = name.to_uppercase();
552        self.functions
553            .iter()
554            .find(|f| f.name() == name_upper)
555            .map(std::convert::AsRef::as_ref)
556    }
557
558    #[must_use]
559    pub fn is_aggregate(&self, name: &str) -> bool {
560        self.get(name).is_some() || name.to_uppercase() == "COUNT" // COUNT(*) special case
561    }
562}
563
564impl Default for AggregateRegistry {
565    fn default() -> Self {
566        Self::new()
567    }
568}
569
570/// Check if an expression contains aggregate functions
571pub fn contains_aggregate(expr: &crate::recursive_parser::SqlExpression) -> bool {
572    use crate::recursive_parser::SqlExpression;
573
574    match expr {
575        SqlExpression::FunctionCall { name, args, .. } => {
576            let registry = AggregateRegistry::new();
577            if registry.is_aggregate(name) {
578                return true;
579            }
580            // Check nested expressions
581            args.iter().any(contains_aggregate)
582        }
583        SqlExpression::BinaryOp { left, right, .. } => {
584            contains_aggregate(left) || contains_aggregate(right)
585        }
586        SqlExpression::Not { expr } => contains_aggregate(expr),
587        SqlExpression::CaseExpression {
588            when_branches,
589            else_branch,
590        } => {
591            when_branches.iter().any(|branch| {
592                contains_aggregate(&branch.condition) || contains_aggregate(&branch.result)
593            }) || else_branch.as_ref().is_some_and(|e| contains_aggregate(e))
594        }
595        _ => false,
596    }
597}
598
599/// Check if an expression is a constant (string literal, number literal, boolean, null)
600/// Constants are compatible with aggregate queries and should produce a single row
601pub fn is_constant_expression(expr: &crate::recursive_parser::SqlExpression) -> bool {
602    use crate::recursive_parser::SqlExpression;
603
604    match expr {
605        SqlExpression::StringLiteral(_) => true,
606        SqlExpression::NumberLiteral(_) => true,
607        SqlExpression::BooleanLiteral(_) => true,
608        SqlExpression::Null => true,
609        SqlExpression::DateTimeConstructor { .. } => true,
610        SqlExpression::DateTimeToday { .. } => true,
611        // Binary operations between constants are also constant
612        SqlExpression::BinaryOp { left, right, .. } => {
613            is_constant_expression(left) && is_constant_expression(right)
614        }
615        // NOT of a constant is still constant
616        SqlExpression::Not { expr } => is_constant_expression(expr),
617        // Case expressions with constant conditions and results are constant
618        SqlExpression::CaseExpression {
619            when_branches,
620            else_branch,
621        } => {
622            when_branches.iter().all(|branch| {
623                is_constant_expression(&branch.condition) && is_constant_expression(&branch.result)
624            }) && else_branch
625                .as_ref()
626                .map_or(true, |e| is_constant_expression(e))
627        }
628        // Function calls that don't reference columns or aggregates are constant
629        // (like CONVERT(100, 'km', 'miles') or mathematical constants)
630        SqlExpression::FunctionCall { args, .. } => {
631            // Only if all arguments are constants and it's not an aggregate
632            !contains_aggregate(expr) && args.iter().all(is_constant_expression)
633        }
634        _ => false,
635    }
636}
637
638/// Check if an expression is aggregate-compatible (either an aggregate or a constant)
639/// This is used to determine if a SELECT list should produce a single row
640pub fn is_aggregate_compatible(expr: &crate::recursive_parser::SqlExpression) -> bool {
641    contains_aggregate(expr) || is_constant_expression(expr)
642}