Skip to main content

sql_cli/sql/aggregates/
mod.rs

1//! Aggregate functions for GROUP BY operations
2//!
3//! This module provides SQL aggregate functions like SUM, AVG, COUNT, MIN, MAX
4//! that work with the `DataView` partitioning system for efficient GROUP BY queries.
5
6use anyhow::{anyhow, Result};
7
8use crate::data::datatable::DataValue;
9
10pub mod analytics;
11pub mod functions;
12
13/// State maintained during aggregation: one variant per kind of accumulator an aggregate function carries between rows
14#[derive(Debug, Clone)]
15pub enum AggregateState {
16    Count(i64), // plain integer counter
17    Sum(SumState), // accumulator for SUM (see `SumState`)
18    Avg(AvgState), // accumulator for AVG (see `AvgState`)
19    MinMax(MinMaxState), // shared by MIN and MAX; `MinMaxState::is_min` selects the direction
20    Variance(VarianceState), // accumulator for VARIANCE/STDDEV (see `VarianceState`)
21    CollectList(Vec<DataValue>), // buffered values; NOTE(review): the consumer is not visible in this module - confirm usage
22    Percentile(PercentileState), // accumulator for PERCENTILE (see `PercentileState`)
23    Mode(ModeState), // accumulator for MODE (see `ModeState`)
24    Analytics(analytics::AnalyticsState), // state type owned by the `analytics` submodule
25    StringAgg(StringAggState), // accumulator for STRING_AGG (see `StringAggState`)
26}
27
28/// State for SUM aggregation
29#[derive(Debug, Clone)]
30pub struct SumState {
31    pub int_sum: Option<i64>,
32    pub float_sum: Option<f64>,
33    pub has_values: bool,
34}
35
36impl Default for SumState {
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42impl SumState {
43    #[must_use]
44    pub fn new() -> Self {
45        Self {
46            int_sum: None,
47            float_sum: None,
48            has_values: false,
49        }
50    }
51
52    pub fn add(&mut self, value: &DataValue) -> Result<()> {
53        match value {
54            DataValue::Null => Ok(()), // Skip nulls
55            DataValue::Integer(n) => {
56                self.has_values = true;
57                if let Some(ref mut sum) = self.int_sum {
58                    *sum = sum.saturating_add(*n);
59                } else if let Some(ref mut fsum) = self.float_sum {
60                    *fsum += *n as f64;
61                } else {
62                    self.int_sum = Some(*n);
63                }
64                Ok(())
65            }
66            DataValue::Float(f) => {
67                self.has_values = true;
68                // Once we have a float, convert everything to float
69                if let Some(isum) = self.int_sum.take() {
70                    self.float_sum = Some(isum as f64 + f);
71                } else if let Some(ref mut fsum) = self.float_sum {
72                    *fsum += f;
73                } else {
74                    self.float_sum = Some(*f);
75                }
76                Ok(())
77            }
78            DataValue::Boolean(b) => {
79                // Coerce boolean to integer: true=1, false=0
80                // Enables patterns like AVG(x > 5), SUM(col = 'value')
81                let n = if *b { 1i64 } else { 0i64 };
82                self.has_values = true;
83                if let Some(ref mut sum) = self.int_sum {
84                    *sum = sum.saturating_add(n);
85                } else if let Some(ref mut fsum) = self.float_sum {
86                    *fsum += n as f64;
87                } else {
88                    self.int_sum = Some(n);
89                }
90                Ok(())
91            }
92            _ => Err(anyhow!("Cannot sum non-numeric value")),
93        }
94    }
95
96    #[must_use]
97    pub fn finalize(self) -> DataValue {
98        if !self.has_values {
99            return DataValue::Null;
100        }
101
102        if let Some(fsum) = self.float_sum {
103            DataValue::Float(fsum)
104        } else if let Some(isum) = self.int_sum {
105            DataValue::Integer(isum)
106        } else {
107            DataValue::Null
108        }
109    }
110}
111
112/// State for AVG aggregation
113#[derive(Debug, Clone)]
114pub struct AvgState {
115    pub sum: SumState,
116    pub count: i64,
117}
118
119impl Default for AvgState {
120    fn default() -> Self {
121        Self::new()
122    }
123}
124
125impl AvgState {
126    #[must_use]
127    pub fn new() -> Self {
128        Self {
129            sum: SumState::new(),
130            count: 0,
131        }
132    }
133
134    pub fn add(&mut self, value: &DataValue) -> Result<()> {
135        if !matches!(value, DataValue::Null) {
136            self.sum.add(value)?;
137            self.count += 1;
138        }
139        Ok(())
140    }
141
142    #[must_use]
143    pub fn finalize(self) -> DataValue {
144        if self.count == 0 {
145            return DataValue::Null;
146        }
147
148        let sum = self.sum.finalize();
149        match sum {
150            DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
151            DataValue::Float(f) => DataValue::Float(f / self.count as f64),
152            _ => DataValue::Null,
153        }
154    }
155}
156
157/// State for MIN/MAX aggregation
158#[derive(Debug, Clone)]
159pub struct MinMaxState {
160    pub is_min: bool,
161    pub current: Option<DataValue>,
162}
163
164impl MinMaxState {
165    #[must_use]
166    pub fn new(is_min: bool) -> Self {
167        Self {
168            is_min,
169            current: None,
170        }
171    }
172
173    pub fn add(&mut self, value: &DataValue) -> Result<()> {
174        if matches!(value, DataValue::Null) {
175            return Ok(());
176        }
177
178        if let Some(ref current) = self.current {
179            let should_update = if self.is_min {
180                value < current
181            } else {
182                value > current
183            };
184
185            if should_update {
186                self.current = Some(value.clone());
187            }
188        } else {
189            self.current = Some(value.clone());
190        }
191
192        Ok(())
193    }
194
195    #[must_use]
196    pub fn finalize(self) -> DataValue {
197        self.current.unwrap_or(DataValue::Null)
198    }
199}
200
201/// State for VARIANCE/STDDEV aggregation
202#[derive(Debug, Clone)]
203pub struct VarianceState {
204    pub sum: f64,
205    pub sum_of_squares: f64,
206    pub count: i64,
207}
208
209impl Default for VarianceState {
210    fn default() -> Self {
211        Self::new()
212    }
213}
214
215impl VarianceState {
216    #[must_use]
217    pub fn new() -> Self {
218        Self {
219            sum: 0.0,
220            sum_of_squares: 0.0,
221            count: 0,
222        }
223    }
224
225    pub fn add(&mut self, value: &DataValue) -> Result<()> {
226        match value {
227            DataValue::Null => Ok(()), // Skip nulls
228            DataValue::Integer(n) => {
229                let f = *n as f64;
230                self.sum += f;
231                self.sum_of_squares += f * f;
232                self.count += 1;
233                Ok(())
234            }
235            DataValue::Float(f) => {
236                self.sum += f;
237                self.sum_of_squares += f * f;
238                self.count += 1;
239                Ok(())
240            }
241            _ => Err(anyhow!("Cannot compute variance of non-numeric value")),
242        }
243    }
244
245    #[must_use]
246    pub fn variance(&self) -> f64 {
247        if self.count <= 1 {
248            return 0.0;
249        }
250        let mean = self.sum / self.count as f64;
251        (self.sum_of_squares / self.count as f64) - (mean * mean)
252    }
253
254    #[must_use]
255    pub fn stddev(&self) -> f64 {
256        self.variance().sqrt()
257    }
258
259    #[must_use]
260    pub fn finalize_variance(self) -> DataValue {
261        if self.count == 0 {
262            DataValue::Null
263        } else {
264            DataValue::Float(self.variance())
265        }
266    }
267
268    #[must_use]
269    pub fn finalize_stddev(self) -> DataValue {
270        if self.count == 0 {
271            DataValue::Null
272        } else {
273            DataValue::Float(self.stddev())
274        }
275    }
276
277    #[must_use]
278    pub fn variance_sample(&self) -> f64 {
279        if self.count <= 1 {
280            return 0.0;
281        }
282        let mean = self.sum / self.count as f64;
283        let variance_n = (self.sum_of_squares / self.count as f64) - (mean * mean);
284        // Convert from population variance to sample variance
285        variance_n * (self.count as f64 / (self.count - 1) as f64)
286    }
287
288    #[must_use]
289    pub fn stddev_sample(&self) -> f64 {
290        self.variance_sample().sqrt()
291    }
292
293    #[must_use]
294    pub fn finalize_variance_sample(self) -> DataValue {
295        if self.count <= 1 {
296            DataValue::Null
297        } else {
298            DataValue::Float(self.variance_sample())
299        }
300    }
301
302    #[must_use]
303    pub fn finalize_stddev_sample(self) -> DataValue {
304        if self.count <= 1 {
305            DataValue::Null
306        } else {
307            DataValue::Float(self.stddev_sample())
308        }
309    }
310}
311
312/// State for PERCENTILE aggregation
313#[derive(Debug, Clone)]
314pub struct PercentileState {
315    pub values: Vec<DataValue>,
316    pub percentile: f64,
317}
318
319impl Default for PercentileState {
320    fn default() -> Self {
321        Self::new(50.0) // Default to median
322    }
323}
324
325impl PercentileState {
326    #[must_use]
327    pub fn new(percentile: f64) -> Self {
328        Self {
329            values: Vec::new(),
330            percentile: percentile.clamp(0.0, 100.0),
331        }
332    }
333
334    pub fn add(&mut self, value: &DataValue) -> Result<()> {
335        if !matches!(value, DataValue::Null) {
336            self.values.push(value.clone());
337        }
338        Ok(())
339    }
340
341    #[must_use]
342    pub fn finalize(mut self) -> DataValue {
343        if self.values.is_empty() {
344            return DataValue::Null;
345        }
346
347        // Sort values for percentile calculation
348        self.values.sort_by(|a, b| {
349            use std::cmp::Ordering;
350            match (a, b) {
351                (DataValue::Integer(a), DataValue::Integer(b)) => a.cmp(b),
352                (DataValue::Float(a), DataValue::Float(b)) => {
353                    a.partial_cmp(b).unwrap_or(Ordering::Equal)
354                }
355                (DataValue::Integer(a), DataValue::Float(b)) => {
356                    (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal)
357                }
358                (DataValue::Float(a), DataValue::Integer(b)) => {
359                    a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal)
360                }
361                _ => Ordering::Equal,
362            }
363        });
364
365        let n = self.values.len();
366        if self.percentile == 0.0 {
367            return self.values[0].clone();
368        }
369        if self.percentile == 100.0 {
370            return self.values[n - 1].clone();
371        }
372
373        // Calculate percentile using linear interpolation
374        let pos = (self.percentile / 100.0) * ((n - 1) as f64);
375        let lower_idx = pos.floor() as usize;
376        let upper_idx = pos.ceil() as usize;
377
378        if lower_idx == upper_idx {
379            // Exact position
380            self.values[lower_idx].clone()
381        } else {
382            // Interpolate between two values
383            let fraction = pos - lower_idx as f64;
384            let lower_val = &self.values[lower_idx];
385            let upper_val = &self.values[upper_idx];
386
387            match (lower_val, upper_val) {
388                (DataValue::Integer(a), DataValue::Integer(b)) => {
389                    let result = *a as f64 + fraction * (*b - *a) as f64;
390                    if result.fract() == 0.0 {
391                        DataValue::Integer(result as i64)
392                    } else {
393                        DataValue::Float(result)
394                    }
395                }
396                (DataValue::Float(a), DataValue::Float(b)) => {
397                    DataValue::Float(a + fraction * (b - a))
398                }
399                (DataValue::Integer(a), DataValue::Float(b)) => {
400                    DataValue::Float(*a as f64 + fraction * (b - *a as f64))
401                }
402                (DataValue::Float(a), DataValue::Integer(b)) => {
403                    DataValue::Float(a + fraction * (*b as f64 - a))
404                }
405                // For non-numeric, return the lower value
406                _ => lower_val.clone(),
407            }
408        }
409    }
410}
411
412/// State for MODE aggregation (most frequent value)
413#[derive(Debug, Clone)]
414pub struct ModeState {
415    pub counts: std::collections::HashMap<String, (DataValue, i64)>,
416}
417
418impl Default for ModeState {
419    fn default() -> Self {
420        Self::new()
421    }
422}
423
424impl ModeState {
425    #[must_use]
426    pub fn new() -> Self {
427        Self {
428            counts: std::collections::HashMap::new(),
429        }
430    }
431
432    pub fn add(&mut self, value: &DataValue) -> Result<()> {
433        if matches!(value, DataValue::Null) {
434            return Ok(());
435        }
436
437        // Convert value to string for hashing, but keep original value for result
438        let key = match value {
439            DataValue::String(s) => s.clone(),
440            DataValue::InternedString(s) => s.to_string(),
441            DataValue::Integer(i) => i.to_string(),
442            DataValue::Float(f) => f.to_string(),
443            DataValue::Boolean(b) => b.to_string(),
444            DataValue::DateTime(dt) => dt.to_string(),
445            DataValue::Vector(v) => {
446                let components: Vec<String> = v.iter().map(|f| f.to_string()).collect();
447                format!("[{}]", components.join(","))
448            }
449            DataValue::Null => return Ok(()),
450        };
451
452        // Update count and store the original value
453        let entry = self.counts.entry(key).or_insert((value.clone(), 0));
454        entry.1 += 1;
455
456        Ok(())
457    }
458
459    #[must_use]
460    pub fn finalize(self) -> DataValue {
461        if self.counts.is_empty() {
462            return DataValue::Null;
463        }
464
465        // Find the value with the highest count
466        let max_entry = self.counts.iter().max_by_key(|(_, (_, count))| count);
467
468        match max_entry {
469            Some((_, (value, _count))) => value.clone(),
470            None => DataValue::Null,
471        }
472    }
473}
474
475/// Trait for all aggregate functions; implementors must be `Send + Sync` and are stored as boxed trait objects in `AggregateRegistry`
476pub trait AggregateFunction: Send + Sync {
477    /// Name of the function (e.g., "SUM", "AVG"); `AggregateRegistry::get` compares it against the upper-cased lookup name, so it is presumably upper-case
478    fn name(&self) -> &str;
479
480    /// Initialize the aggregation state: returns a fresh `AggregateState` for this function
481    fn init(&self) -> AggregateState;
482
483    /// Add a value to the aggregation by updating `state` in place; returns an error when the value cannot be accumulated
484    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()>;
485
486    /// Finalize and return the result, consuming the accumulated `state`
487    fn finalize(&self, state: AggregateState) -> DataValue;
488
489    /// Check if this function requires numeric input; the default implementation returns `false`
490    fn requires_numeric(&self) -> bool {
491        false
492    }
493}
494
495/// State for STRING_AGG aggregation
496#[derive(Debug, Clone)]
497pub struct StringAggState {
498    pub values: Vec<String>,
499    pub separator: String,
500}
501
502impl Default for StringAggState {
503    fn default() -> Self {
504        Self::new(",")
505    }
506}
507
508impl StringAggState {
509    #[must_use]
510    pub fn new(separator: &str) -> Self {
511        Self {
512            values: Vec::new(),
513            separator: separator.to_string(),
514        }
515    }
516
517    pub fn add(&mut self, value: &DataValue) -> Result<()> {
518        match value {
519            DataValue::Null => Ok(()), // Skip nulls
520            DataValue::String(s) => {
521                self.values.push(s.clone());
522                Ok(())
523            }
524            DataValue::InternedString(s) => {
525                self.values.push(s.to_string());
526                Ok(())
527            }
528            DataValue::Integer(n) => {
529                self.values.push(n.to_string());
530                Ok(())
531            }
532            DataValue::Float(f) => {
533                self.values.push(f.to_string());
534                Ok(())
535            }
536            DataValue::Boolean(b) => {
537                self.values.push(b.to_string());
538                Ok(())
539            }
540            DataValue::DateTime(dt) => {
541                self.values.push(dt.to_string());
542                Ok(())
543            }
544            DataValue::Vector(v) => {
545                let components: Vec<String> = v.iter().map(|f| f.to_string()).collect();
546                self.values.push(format!("[{}]", components.join(",")));
547                Ok(())
548            }
549        }
550    }
551
552    #[must_use]
553    pub fn finalize(self) -> DataValue {
554        if self.values.is_empty() {
555            DataValue::Null
556        } else {
557            DataValue::String(self.values.join(&self.separator))
558        }
559    }
560}
561
562/// Registry of aggregate functions
563pub struct AggregateRegistry {
564    functions: Vec<Box<dyn AggregateFunction>>,
565}
566
567impl AggregateRegistry {
568    #[must_use]
569    pub fn new() -> Self {
570        use analytics::{
571            CumMaxFunction, CumMinFunction, DeltasFunction, MavgFunction, PctChangeFunction,
572            RankFunction, SumsFunction,
573        };
574        use functions::{
575            AvgFunction, MaxFunction, MedianFunction, MinFunction, ModeFunction,
576            PercentileFunction, StdDevFunction, StdDevPopFunction, StdDevSampFunction,
577            StringAggFunction, VarPopFunction, VarSampFunction, VarianceFunction,
578        };
579
580        let functions: Vec<Box<dyn AggregateFunction>> = vec![
581            // Box::new(CountFunction), // MIGRATED to new registry
582            // Box::new(CountStarFunction), // MIGRATED to new registry
583            // Box::new(SumFunction), // MIGRATED to new registry
584            Box::new(AvgFunction),
585            Box::new(MinFunction),
586            Box::new(MaxFunction),
587            Box::new(StdDevFunction),
588            Box::new(StdDevPopFunction),
589            Box::new(StdDevSampFunction),
590            Box::new(VarianceFunction),
591            Box::new(VarPopFunction),
592            Box::new(VarSampFunction),
593            Box::new(MedianFunction),
594            Box::new(ModeFunction),
595            Box::new(PercentileFunction),
596            Box::new(StringAggFunction),
597            // Analytics functions
598            Box::new(DeltasFunction),
599            Box::new(SumsFunction),
600            Box::new(MavgFunction),
601            Box::new(PctChangeFunction),
602            Box::new(RankFunction),
603            Box::new(CumMaxFunction),
604            Box::new(CumMinFunction),
605        ];
606
607        Self { functions }
608    }
609
610    #[must_use]
611    pub fn get(&self, name: &str) -> Option<&dyn AggregateFunction> {
612        let name_upper = name.to_uppercase();
613        self.functions
614            .iter()
615            .find(|f| f.name() == name_upper)
616            .map(std::convert::AsRef::as_ref)
617    }
618
619    #[must_use]
620    pub fn is_aggregate(&self, name: &str) -> bool {
621        use crate::sql::aggregate_functions::AggregateFunctionRegistry;
622
623        // Check this registry
624        if self.get(name).is_some() {
625            return true;
626        }
627
628        // Also check new registry for migrated functions (including COUNT, COUNT_STAR, SUM)
629        let new_registry = AggregateFunctionRegistry::new();
630        new_registry.contains(name)
631    }
632}
633
634impl Default for AggregateRegistry {
635    fn default() -> Self {
636        Self::new()
637    }
638}
639
640/// Check if an expression contains aggregate functions
641pub fn contains_aggregate(expr: &crate::recursive_parser::SqlExpression) -> bool {
642    use crate::recursive_parser::SqlExpression;
643    use crate::sql::aggregate_functions::AggregateFunctionRegistry;
644
645    match expr {
646        SqlExpression::FunctionCall { name, args, .. } => {
647            // Check old registry
648            let registry = AggregateRegistry::new();
649            if registry.is_aggregate(name) {
650                return true;
651            }
652            // Check new registry for migrated functions
653            let new_registry = AggregateFunctionRegistry::new();
654            if new_registry.contains(name) {
655                return true;
656            }
657            // Check nested expressions
658            args.iter().any(contains_aggregate)
659        }
660        SqlExpression::BinaryOp { left, right, .. } => {
661            contains_aggregate(left) || contains_aggregate(right)
662        }
663        SqlExpression::Not { expr } => contains_aggregate(expr),
664        SqlExpression::CaseExpression {
665            when_branches,
666            else_branch,
667        } => {
668            when_branches.iter().any(|branch| {
669                contains_aggregate(&branch.condition) || contains_aggregate(&branch.result)
670            }) || else_branch.as_ref().is_some_and(|e| contains_aggregate(e))
671        }
672        _ => false,
673    }
674}
675
676/// Check if an expression is a constant (string literal, number literal, boolean, null)
677/// Constants are compatible with aggregate queries and should produce a single row
678pub fn is_constant_expression(expr: &crate::recursive_parser::SqlExpression) -> bool {
679    use crate::recursive_parser::SqlExpression;
680
681    match expr {
682        SqlExpression::StringLiteral(_) => true,
683        SqlExpression::NumberLiteral(_) => true,
684        SqlExpression::BooleanLiteral(_) => true,
685        SqlExpression::Null => true,
686        SqlExpression::DateTimeConstructor { .. } => true,
687        SqlExpression::DateTimeToday { .. } => true,
688        // Binary operations between constants are also constant
689        SqlExpression::BinaryOp { left, right, .. } => {
690            is_constant_expression(left) && is_constant_expression(right)
691        }
692        // NOT of a constant is still constant
693        SqlExpression::Not { expr } => is_constant_expression(expr),
694        // Case expressions with constant conditions and results are constant
695        SqlExpression::CaseExpression {
696            when_branches,
697            else_branch,
698        } => {
699            when_branches.iter().all(|branch| {
700                is_constant_expression(&branch.condition) && is_constant_expression(&branch.result)
701            }) && else_branch
702                .as_ref()
703                .map_or(true, |e| is_constant_expression(e))
704        }
705        // Function calls that don't reference columns or aggregates are constant
706        // (like CONVERT(100, 'km', 'miles') or mathematical constants)
707        SqlExpression::FunctionCall { args, .. } => {
708            // Only if all arguments are constants and it's not an aggregate
709            !contains_aggregate(expr) && args.iter().all(is_constant_expression)
710        }
711        _ => false,
712    }
713}
714
715/// Check if an expression is aggregate-compatible (either an aggregate or a constant)
716/// This is used to determine if a SELECT list should produce a single row
717pub fn is_aggregate_compatible(expr: &crate::recursive_parser::SqlExpression) -> bool {
718    contains_aggregate(expr) || is_constant_expression(expr)
719}