sql_cli/sql/aggregates/
mod.rs

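//! Aggregate functions for GROUP BY operations
//!
//! This module provides SQL aggregate functions such as AVG, MIN and MAX (COUNT
//! and SUM now live in the newer `crate::sql::aggregate_functions` registry),
//! together with the per-group state types they accumulate into. They work with
//! the `DataView` partitioning system for efficient GROUP BY queries.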
5
6use anyhow::{anyhow, Result};
7
8use crate::data::datatable::DataValue;
9
10pub mod analytics;
11pub mod functions;
12
13/// State maintained during aggregation
14#[derive(Debug, Clone)]
15pub enum AggregateState {
16    Count(i64),
17    Sum(SumState),
18    Avg(AvgState),
19    MinMax(MinMaxState),
20    Variance(VarianceState),
21    CollectList(Vec<DataValue>),
22    Percentile(PercentileState),
23    Mode(ModeState),
24    Analytics(analytics::AnalyticsState),
25    StringAgg(StringAggState),
26}
27
28/// State for SUM aggregation
29#[derive(Debug, Clone)]
30pub struct SumState {
31    pub int_sum: Option<i64>,
32    pub float_sum: Option<f64>,
33    pub has_values: bool,
34}
35
36impl Default for SumState {
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42impl SumState {
43    #[must_use]
44    pub fn new() -> Self {
45        Self {
46            int_sum: None,
47            float_sum: None,
48            has_values: false,
49        }
50    }
51
52    pub fn add(&mut self, value: &DataValue) -> Result<()> {
53        match value {
54            DataValue::Null => Ok(()), // Skip nulls
55            DataValue::Integer(n) => {
56                self.has_values = true;
57                if let Some(ref mut sum) = self.int_sum {
58                    *sum = sum.saturating_add(*n);
59                } else if let Some(ref mut fsum) = self.float_sum {
60                    *fsum += *n as f64;
61                } else {
62                    self.int_sum = Some(*n);
63                }
64                Ok(())
65            }
66            DataValue::Float(f) => {
67                self.has_values = true;
68                // Once we have a float, convert everything to float
69                if let Some(isum) = self.int_sum.take() {
70                    self.float_sum = Some(isum as f64 + f);
71                } else if let Some(ref mut fsum) = self.float_sum {
72                    *fsum += f;
73                } else {
74                    self.float_sum = Some(*f);
75                }
76                Ok(())
77            }
78            _ => Err(anyhow!("Cannot sum non-numeric value")),
79        }
80    }
81
82    #[must_use]
83    pub fn finalize(self) -> DataValue {
84        if !self.has_values {
85            return DataValue::Null;
86        }
87
88        if let Some(fsum) = self.float_sum {
89            DataValue::Float(fsum)
90        } else if let Some(isum) = self.int_sum {
91            DataValue::Integer(isum)
92        } else {
93            DataValue::Null
94        }
95    }
96}
97
98/// State for AVG aggregation
99#[derive(Debug, Clone)]
100pub struct AvgState {
101    pub sum: SumState,
102    pub count: i64,
103}
104
105impl Default for AvgState {
106    fn default() -> Self {
107        Self::new()
108    }
109}
110
111impl AvgState {
112    #[must_use]
113    pub fn new() -> Self {
114        Self {
115            sum: SumState::new(),
116            count: 0,
117        }
118    }
119
120    pub fn add(&mut self, value: &DataValue) -> Result<()> {
121        if !matches!(value, DataValue::Null) {
122            self.sum.add(value)?;
123            self.count += 1;
124        }
125        Ok(())
126    }
127
128    #[must_use]
129    pub fn finalize(self) -> DataValue {
130        if self.count == 0 {
131            return DataValue::Null;
132        }
133
134        let sum = self.sum.finalize();
135        match sum {
136            DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
137            DataValue::Float(f) => DataValue::Float(f / self.count as f64),
138            _ => DataValue::Null,
139        }
140    }
141}
142
143/// State for MIN/MAX aggregation
144#[derive(Debug, Clone)]
145pub struct MinMaxState {
146    pub is_min: bool,
147    pub current: Option<DataValue>,
148}
149
150impl MinMaxState {
151    #[must_use]
152    pub fn new(is_min: bool) -> Self {
153        Self {
154            is_min,
155            current: None,
156        }
157    }
158
159    pub fn add(&mut self, value: &DataValue) -> Result<()> {
160        if matches!(value, DataValue::Null) {
161            return Ok(());
162        }
163
164        if let Some(ref current) = self.current {
165            let should_update = if self.is_min {
166                value < current
167            } else {
168                value > current
169            };
170
171            if should_update {
172                self.current = Some(value.clone());
173            }
174        } else {
175            self.current = Some(value.clone());
176        }
177
178        Ok(())
179    }
180
181    #[must_use]
182    pub fn finalize(self) -> DataValue {
183        self.current.unwrap_or(DataValue::Null)
184    }
185}
186
187/// State for VARIANCE/STDDEV aggregation
188#[derive(Debug, Clone)]
189pub struct VarianceState {
190    pub sum: f64,
191    pub sum_of_squares: f64,
192    pub count: i64,
193}
194
195impl Default for VarianceState {
196    fn default() -> Self {
197        Self::new()
198    }
199}
200
201impl VarianceState {
202    #[must_use]
203    pub fn new() -> Self {
204        Self {
205            sum: 0.0,
206            sum_of_squares: 0.0,
207            count: 0,
208        }
209    }
210
211    pub fn add(&mut self, value: &DataValue) -> Result<()> {
212        match value {
213            DataValue::Null => Ok(()), // Skip nulls
214            DataValue::Integer(n) => {
215                let f = *n as f64;
216                self.sum += f;
217                self.sum_of_squares += f * f;
218                self.count += 1;
219                Ok(())
220            }
221            DataValue::Float(f) => {
222                self.sum += f;
223                self.sum_of_squares += f * f;
224                self.count += 1;
225                Ok(())
226            }
227            _ => Err(anyhow!("Cannot compute variance of non-numeric value")),
228        }
229    }
230
231    #[must_use]
232    pub fn variance(&self) -> f64 {
233        if self.count <= 1 {
234            return 0.0;
235        }
236        let mean = self.sum / self.count as f64;
237        (self.sum_of_squares / self.count as f64) - (mean * mean)
238    }
239
240    #[must_use]
241    pub fn stddev(&self) -> f64 {
242        self.variance().sqrt()
243    }
244
245    #[must_use]
246    pub fn finalize_variance(self) -> DataValue {
247        if self.count == 0 {
248            DataValue::Null
249        } else {
250            DataValue::Float(self.variance())
251        }
252    }
253
254    #[must_use]
255    pub fn finalize_stddev(self) -> DataValue {
256        if self.count == 0 {
257            DataValue::Null
258        } else {
259            DataValue::Float(self.stddev())
260        }
261    }
262
263    #[must_use]
264    pub fn variance_sample(&self) -> f64 {
265        if self.count <= 1 {
266            return 0.0;
267        }
268        let mean = self.sum / self.count as f64;
269        let variance_n = (self.sum_of_squares / self.count as f64) - (mean * mean);
270        // Convert from population variance to sample variance
271        variance_n * (self.count as f64 / (self.count - 1) as f64)
272    }
273
274    #[must_use]
275    pub fn stddev_sample(&self) -> f64 {
276        self.variance_sample().sqrt()
277    }
278
279    #[must_use]
280    pub fn finalize_variance_sample(self) -> DataValue {
281        if self.count <= 1 {
282            DataValue::Null
283        } else {
284            DataValue::Float(self.variance_sample())
285        }
286    }
287
288    #[must_use]
289    pub fn finalize_stddev_sample(self) -> DataValue {
290        if self.count <= 1 {
291            DataValue::Null
292        } else {
293            DataValue::Float(self.stddev_sample())
294        }
295    }
296}
297
298/// State for PERCENTILE aggregation
299#[derive(Debug, Clone)]
300pub struct PercentileState {
301    pub values: Vec<DataValue>,
302    pub percentile: f64,
303}
304
305impl Default for PercentileState {
306    fn default() -> Self {
307        Self::new(50.0) // Default to median
308    }
309}
310
311impl PercentileState {
312    #[must_use]
313    pub fn new(percentile: f64) -> Self {
314        Self {
315            values: Vec::new(),
316            percentile: percentile.clamp(0.0, 100.0),
317        }
318    }
319
320    pub fn add(&mut self, value: &DataValue) -> Result<()> {
321        if !matches!(value, DataValue::Null) {
322            self.values.push(value.clone());
323        }
324        Ok(())
325    }
326
327    #[must_use]
328    pub fn finalize(mut self) -> DataValue {
329        if self.values.is_empty() {
330            return DataValue::Null;
331        }
332
333        // Sort values for percentile calculation
334        self.values.sort_by(|a, b| {
335            use std::cmp::Ordering;
336            match (a, b) {
337                (DataValue::Integer(a), DataValue::Integer(b)) => a.cmp(b),
338                (DataValue::Float(a), DataValue::Float(b)) => {
339                    a.partial_cmp(b).unwrap_or(Ordering::Equal)
340                }
341                (DataValue::Integer(a), DataValue::Float(b)) => {
342                    (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal)
343                }
344                (DataValue::Float(a), DataValue::Integer(b)) => {
345                    a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal)
346                }
347                _ => Ordering::Equal,
348            }
349        });
350
351        let n = self.values.len();
352        if self.percentile == 0.0 {
353            return self.values[0].clone();
354        }
355        if self.percentile == 100.0 {
356            return self.values[n - 1].clone();
357        }
358
359        // Calculate percentile using linear interpolation
360        let pos = (self.percentile / 100.0) * ((n - 1) as f64);
361        let lower_idx = pos.floor() as usize;
362        let upper_idx = pos.ceil() as usize;
363
364        if lower_idx == upper_idx {
365            // Exact position
366            self.values[lower_idx].clone()
367        } else {
368            // Interpolate between two values
369            let fraction = pos - lower_idx as f64;
370            let lower_val = &self.values[lower_idx];
371            let upper_val = &self.values[upper_idx];
372
373            match (lower_val, upper_val) {
374                (DataValue::Integer(a), DataValue::Integer(b)) => {
375                    let result = *a as f64 + fraction * (*b - *a) as f64;
376                    if result.fract() == 0.0 {
377                        DataValue::Integer(result as i64)
378                    } else {
379                        DataValue::Float(result)
380                    }
381                }
382                (DataValue::Float(a), DataValue::Float(b)) => {
383                    DataValue::Float(a + fraction * (b - a))
384                }
385                (DataValue::Integer(a), DataValue::Float(b)) => {
386                    DataValue::Float(*a as f64 + fraction * (b - *a as f64))
387                }
388                (DataValue::Float(a), DataValue::Integer(b)) => {
389                    DataValue::Float(a + fraction * (*b as f64 - a))
390                }
391                // For non-numeric, return the lower value
392                _ => lower_val.clone(),
393            }
394        }
395    }
396}
397
398/// State for MODE aggregation (most frequent value)
399#[derive(Debug, Clone)]
400pub struct ModeState {
401    pub counts: std::collections::HashMap<String, (DataValue, i64)>,
402}
403
404impl Default for ModeState {
405    fn default() -> Self {
406        Self::new()
407    }
408}
409
410impl ModeState {
411    #[must_use]
412    pub fn new() -> Self {
413        Self {
414            counts: std::collections::HashMap::new(),
415        }
416    }
417
418    pub fn add(&mut self, value: &DataValue) -> Result<()> {
419        if matches!(value, DataValue::Null) {
420            return Ok(());
421        }
422
423        // Convert value to string for hashing, but keep original value for result
424        let key = match value {
425            DataValue::String(s) => s.clone(),
426            DataValue::InternedString(s) => s.to_string(),
427            DataValue::Integer(i) => i.to_string(),
428            DataValue::Float(f) => f.to_string(),
429            DataValue::Boolean(b) => b.to_string(),
430            DataValue::DateTime(dt) => dt.to_string(),
431            DataValue::Null => return Ok(()),
432        };
433
434        // Update count and store the original value
435        let entry = self.counts.entry(key).or_insert((value.clone(), 0));
436        entry.1 += 1;
437
438        Ok(())
439    }
440
441    #[must_use]
442    pub fn finalize(self) -> DataValue {
443        if self.counts.is_empty() {
444            return DataValue::Null;
445        }
446
447        // Find the value with the highest count
448        let max_entry = self.counts.iter().max_by_key(|(_, (_, count))| count);
449
450        match max_entry {
451            Some((_, (value, _count))) => value.clone(),
452            None => DataValue::Null,
453        }
454    }
455}
456
457/// Trait for all aggregate functions
458pub trait AggregateFunction: Send + Sync {
459    /// Name of the function (e.g., "SUM", "AVG")
460    fn name(&self) -> &str;
461
462    /// Initialize the aggregation state
463    fn init(&self) -> AggregateState;
464
465    /// Add a value to the aggregation
466    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()>;
467
468    /// Finalize and return the result
469    fn finalize(&self, state: AggregateState) -> DataValue;
470
471    /// Check if this function requires numeric input
472    fn requires_numeric(&self) -> bool {
473        false
474    }
475}
476
477/// State for STRING_AGG aggregation
478#[derive(Debug, Clone)]
479pub struct StringAggState {
480    pub values: Vec<String>,
481    pub separator: String,
482}
483
484impl Default for StringAggState {
485    fn default() -> Self {
486        Self::new(",")
487    }
488}
489
490impl StringAggState {
491    #[must_use]
492    pub fn new(separator: &str) -> Self {
493        Self {
494            values: Vec::new(),
495            separator: separator.to_string(),
496        }
497    }
498
499    pub fn add(&mut self, value: &DataValue) -> Result<()> {
500        match value {
501            DataValue::Null => Ok(()), // Skip nulls
502            DataValue::String(s) => {
503                self.values.push(s.clone());
504                Ok(())
505            }
506            DataValue::InternedString(s) => {
507                self.values.push(s.to_string());
508                Ok(())
509            }
510            DataValue::Integer(n) => {
511                self.values.push(n.to_string());
512                Ok(())
513            }
514            DataValue::Float(f) => {
515                self.values.push(f.to_string());
516                Ok(())
517            }
518            DataValue::Boolean(b) => {
519                self.values.push(b.to_string());
520                Ok(())
521            }
522            DataValue::DateTime(dt) => {
523                self.values.push(dt.to_string());
524                Ok(())
525            }
526        }
527    }
528
529    #[must_use]
530    pub fn finalize(self) -> DataValue {
531        if self.values.is_empty() {
532            DataValue::Null
533        } else {
534            DataValue::String(self.values.join(&self.separator))
535        }
536    }
537}
538
539/// Registry of aggregate functions
540pub struct AggregateRegistry {
541    functions: Vec<Box<dyn AggregateFunction>>,
542}
543
544impl AggregateRegistry {
545    #[must_use]
546    pub fn new() -> Self {
547        use analytics::{
548            CumMaxFunction, CumMinFunction, DeltasFunction, MavgFunction, PctChangeFunction,
549            RankFunction, SumsFunction,
550        };
551        use functions::{
552            AvgFunction, MaxFunction, MedianFunction, MinFunction, ModeFunction,
553            PercentileFunction, StdDevFunction, StdDevPopFunction, StdDevSampFunction,
554            StringAggFunction, VarPopFunction, VarSampFunction, VarianceFunction,
555        };
556
557        let functions: Vec<Box<dyn AggregateFunction>> = vec![
558            // Box::new(CountFunction), // MIGRATED to new registry
559            // Box::new(CountStarFunction), // MIGRATED to new registry
560            // Box::new(SumFunction), // MIGRATED to new registry
561            Box::new(AvgFunction),
562            Box::new(MinFunction),
563            Box::new(MaxFunction),
564            Box::new(StdDevFunction),
565            Box::new(StdDevPopFunction),
566            Box::new(StdDevSampFunction),
567            Box::new(VarianceFunction),
568            Box::new(VarPopFunction),
569            Box::new(VarSampFunction),
570            Box::new(MedianFunction),
571            Box::new(ModeFunction),
572            Box::new(PercentileFunction),
573            Box::new(StringAggFunction),
574            // Analytics functions
575            Box::new(DeltasFunction),
576            Box::new(SumsFunction),
577            Box::new(MavgFunction),
578            Box::new(PctChangeFunction),
579            Box::new(RankFunction),
580            Box::new(CumMaxFunction),
581            Box::new(CumMinFunction),
582        ];
583
584        Self { functions }
585    }
586
587    #[must_use]
588    pub fn get(&self, name: &str) -> Option<&dyn AggregateFunction> {
589        let name_upper = name.to_uppercase();
590        self.functions
591            .iter()
592            .find(|f| f.name() == name_upper)
593            .map(std::convert::AsRef::as_ref)
594    }
595
596    #[must_use]
597    pub fn is_aggregate(&self, name: &str) -> bool {
598        use crate::sql::aggregate_functions::AggregateFunctionRegistry;
599
600        // Check this registry
601        if self.get(name).is_some() {
602            return true;
603        }
604
605        // Also check new registry for migrated functions (including COUNT, COUNT_STAR, SUM)
606        let new_registry = AggregateFunctionRegistry::new();
607        new_registry.contains(name)
608    }
609}
610
611impl Default for AggregateRegistry {
612    fn default() -> Self {
613        Self::new()
614    }
615}
616
617/// Check if an expression contains aggregate functions
618pub fn contains_aggregate(expr: &crate::recursive_parser::SqlExpression) -> bool {
619    use crate::recursive_parser::SqlExpression;
620    use crate::sql::aggregate_functions::AggregateFunctionRegistry;
621
622    match expr {
623        SqlExpression::FunctionCall { name, args, .. } => {
624            // Check old registry
625            let registry = AggregateRegistry::new();
626            if registry.is_aggregate(name) {
627                return true;
628            }
629            // Check new registry for migrated functions
630            let new_registry = AggregateFunctionRegistry::new();
631            if new_registry.contains(name) {
632                return true;
633            }
634            // Check nested expressions
635            args.iter().any(contains_aggregate)
636        }
637        SqlExpression::BinaryOp { left, right, .. } => {
638            contains_aggregate(left) || contains_aggregate(right)
639        }
640        SqlExpression::Not { expr } => contains_aggregate(expr),
641        SqlExpression::CaseExpression {
642            when_branches,
643            else_branch,
644        } => {
645            when_branches.iter().any(|branch| {
646                contains_aggregate(&branch.condition) || contains_aggregate(&branch.result)
647            }) || else_branch.as_ref().is_some_and(|e| contains_aggregate(e))
648        }
649        _ => false,
650    }
651}
652
653/// Check if an expression is a constant (string literal, number literal, boolean, null)
654/// Constants are compatible with aggregate queries and should produce a single row
655pub fn is_constant_expression(expr: &crate::recursive_parser::SqlExpression) -> bool {
656    use crate::recursive_parser::SqlExpression;
657
658    match expr {
659        SqlExpression::StringLiteral(_) => true,
660        SqlExpression::NumberLiteral(_) => true,
661        SqlExpression::BooleanLiteral(_) => true,
662        SqlExpression::Null => true,
663        SqlExpression::DateTimeConstructor { .. } => true,
664        SqlExpression::DateTimeToday { .. } => true,
665        // Binary operations between constants are also constant
666        SqlExpression::BinaryOp { left, right, .. } => {
667            is_constant_expression(left) && is_constant_expression(right)
668        }
669        // NOT of a constant is still constant
670        SqlExpression::Not { expr } => is_constant_expression(expr),
671        // Case expressions with constant conditions and results are constant
672        SqlExpression::CaseExpression {
673            when_branches,
674            else_branch,
675        } => {
676            when_branches.iter().all(|branch| {
677                is_constant_expression(&branch.condition) && is_constant_expression(&branch.result)
678            }) && else_branch
679                .as_ref()
680                .map_or(true, |e| is_constant_expression(e))
681        }
682        // Function calls that don't reference columns or aggregates are constant
683        // (like CONVERT(100, 'km', 'miles') or mathematical constants)
684        SqlExpression::FunctionCall { args, .. } => {
685            // Only if all arguments are constants and it's not an aggregate
686            !contains_aggregate(expr) && args.iter().all(is_constant_expression)
687        }
688        _ => false,
689    }
690}
691
692/// Check if an expression is aggregate-compatible (either an aggregate or a constant)
693/// This is used to determine if a SELECT list should produce a single row
694pub fn is_aggregate_compatible(expr: &crate::recursive_parser::SqlExpression) -> bool {
695    contains_aggregate(expr) || is_constant_expression(expr)
696}