1use anyhow::{anyhow, Result};
7
8use crate::data::datatable::DataValue;
9
10pub mod analytics;
11pub mod functions;
12
/// State carried by an aggregate while values are fed to it.
///
/// `AggregateFunction::init` returns one of these, `accumulate` mutates it and
/// `finalize` consumes it. Each variant wraps the accumulator used by one
/// family of aggregates.
#[derive(Debug, Clone)]
pub enum AggregateState {
    /// Plain `i64` counter.
    Count(i64),
    /// Running total; see [`SumState`].
    Sum(SumState),
    /// Running total plus value count; see [`AvgState`].
    Avg(AvgState),
    /// Current extreme value; see [`MinMaxState`].
    MinMax(MinMaxState),
    /// Sums needed for variance / standard deviation; see [`VarianceState`].
    Variance(VarianceState),
    /// Raw list of collected values.
    CollectList(Vec<DataValue>),
    /// Buffered values plus the requested percentile; see [`PercentileState`].
    Percentile(PercentileState),
    /// Per-value occurrence counts; see [`ModeState`].
    Mode(ModeState),
    /// State defined by the `analytics` submodule.
    Analytics(analytics::AnalyticsState),
    /// Collected strings plus separator; see [`StringAggState`].
    StringAgg(StringAggState),
}
27
28#[derive(Debug, Clone)]
30pub struct SumState {
31    pub int_sum: Option<i64>,
32    pub float_sum: Option<f64>,
33    pub has_values: bool,
34}
35
36impl Default for SumState {
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42impl SumState {
43    #[must_use]
44    pub fn new() -> Self {
45        Self {
46            int_sum: None,
47            float_sum: None,
48            has_values: false,
49        }
50    }
51
52    pub fn add(&mut self, value: &DataValue) -> Result<()> {
53        match value {
54            DataValue::Null => Ok(()), DataValue::Integer(n) => {
56                self.has_values = true;
57                if let Some(ref mut sum) = self.int_sum {
58                    *sum = sum.saturating_add(*n);
59                } else if let Some(ref mut fsum) = self.float_sum {
60                    *fsum += *n as f64;
61                } else {
62                    self.int_sum = Some(*n);
63                }
64                Ok(())
65            }
66            DataValue::Float(f) => {
67                self.has_values = true;
68                if let Some(isum) = self.int_sum.take() {
70                    self.float_sum = Some(isum as f64 + f);
71                } else if let Some(ref mut fsum) = self.float_sum {
72                    *fsum += f;
73                } else {
74                    self.float_sum = Some(*f);
75                }
76                Ok(())
77            }
78            _ => Err(anyhow!("Cannot sum non-numeric value")),
79        }
80    }
81
82    #[must_use]
83    pub fn finalize(self) -> DataValue {
84        if !self.has_values {
85            return DataValue::Null;
86        }
87
88        if let Some(fsum) = self.float_sum {
89            DataValue::Float(fsum)
90        } else if let Some(isum) = self.int_sum {
91            DataValue::Integer(isum)
92        } else {
93            DataValue::Null
94        }
95    }
96}
97
98#[derive(Debug, Clone)]
100pub struct AvgState {
101    pub sum: SumState,
102    pub count: i64,
103}
104
105impl Default for AvgState {
106    fn default() -> Self {
107        Self::new()
108    }
109}
110
111impl AvgState {
112    #[must_use]
113    pub fn new() -> Self {
114        Self {
115            sum: SumState::new(),
116            count: 0,
117        }
118    }
119
120    pub fn add(&mut self, value: &DataValue) -> Result<()> {
121        if !matches!(value, DataValue::Null) {
122            self.sum.add(value)?;
123            self.count += 1;
124        }
125        Ok(())
126    }
127
128    #[must_use]
129    pub fn finalize(self) -> DataValue {
130        if self.count == 0 {
131            return DataValue::Null;
132        }
133
134        let sum = self.sum.finalize();
135        match sum {
136            DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
137            DataValue::Float(f) => DataValue::Float(f / self.count as f64),
138            _ => DataValue::Null,
139        }
140    }
141}
142
143#[derive(Debug, Clone)]
145pub struct MinMaxState {
146    pub is_min: bool,
147    pub current: Option<DataValue>,
148}
149
150impl MinMaxState {
151    #[must_use]
152    pub fn new(is_min: bool) -> Self {
153        Self {
154            is_min,
155            current: None,
156        }
157    }
158
159    pub fn add(&mut self, value: &DataValue) -> Result<()> {
160        if matches!(value, DataValue::Null) {
161            return Ok(());
162        }
163
164        if let Some(ref current) = self.current {
165            let should_update = if self.is_min {
166                value < current
167            } else {
168                value > current
169            };
170
171            if should_update {
172                self.current = Some(value.clone());
173            }
174        } else {
175            self.current = Some(value.clone());
176        }
177
178        Ok(())
179    }
180
181    #[must_use]
182    pub fn finalize(self) -> DataValue {
183        self.current.unwrap_or(DataValue::Null)
184    }
185}
186
187#[derive(Debug, Clone)]
189pub struct VarianceState {
190    pub sum: f64,
191    pub sum_of_squares: f64,
192    pub count: i64,
193}
194
195impl Default for VarianceState {
196    fn default() -> Self {
197        Self::new()
198    }
199}
200
201impl VarianceState {
202    #[must_use]
203    pub fn new() -> Self {
204        Self {
205            sum: 0.0,
206            sum_of_squares: 0.0,
207            count: 0,
208        }
209    }
210
211    pub fn add(&mut self, value: &DataValue) -> Result<()> {
212        match value {
213            DataValue::Null => Ok(()), DataValue::Integer(n) => {
215                let f = *n as f64;
216                self.sum += f;
217                self.sum_of_squares += f * f;
218                self.count += 1;
219                Ok(())
220            }
221            DataValue::Float(f) => {
222                self.sum += f;
223                self.sum_of_squares += f * f;
224                self.count += 1;
225                Ok(())
226            }
227            _ => Err(anyhow!("Cannot compute variance of non-numeric value")),
228        }
229    }
230
231    #[must_use]
232    pub fn variance(&self) -> f64 {
233        if self.count <= 1 {
234            return 0.0;
235        }
236        let mean = self.sum / self.count as f64;
237        (self.sum_of_squares / self.count as f64) - (mean * mean)
238    }
239
240    #[must_use]
241    pub fn stddev(&self) -> f64 {
242        self.variance().sqrt()
243    }
244
245    #[must_use]
246    pub fn finalize_variance(self) -> DataValue {
247        if self.count == 0 {
248            DataValue::Null
249        } else {
250            DataValue::Float(self.variance())
251        }
252    }
253
254    #[must_use]
255    pub fn finalize_stddev(self) -> DataValue {
256        if self.count == 0 {
257            DataValue::Null
258        } else {
259            DataValue::Float(self.stddev())
260        }
261    }
262
263    #[must_use]
264    pub fn variance_sample(&self) -> f64 {
265        if self.count <= 1 {
266            return 0.0;
267        }
268        let mean = self.sum / self.count as f64;
269        let variance_n = (self.sum_of_squares / self.count as f64) - (mean * mean);
270        variance_n * (self.count as f64 / (self.count - 1) as f64)
272    }
273
274    #[must_use]
275    pub fn stddev_sample(&self) -> f64 {
276        self.variance_sample().sqrt()
277    }
278
279    #[must_use]
280    pub fn finalize_variance_sample(self) -> DataValue {
281        if self.count <= 1 {
282            DataValue::Null
283        } else {
284            DataValue::Float(self.variance_sample())
285        }
286    }
287
288    #[must_use]
289    pub fn finalize_stddev_sample(self) -> DataValue {
290        if self.count <= 1 {
291            DataValue::Null
292        } else {
293            DataValue::Float(self.stddev_sample())
294        }
295    }
296}
297
298#[derive(Debug, Clone)]
300pub struct PercentileState {
301    pub values: Vec<DataValue>,
302    pub percentile: f64,
303}
304
305impl Default for PercentileState {
306    fn default() -> Self {
307        Self::new(50.0) }
309}
310
311impl PercentileState {
312    #[must_use]
313    pub fn new(percentile: f64) -> Self {
314        Self {
315            values: Vec::new(),
316            percentile: percentile.clamp(0.0, 100.0),
317        }
318    }
319
320    pub fn add(&mut self, value: &DataValue) -> Result<()> {
321        if !matches!(value, DataValue::Null) {
322            self.values.push(value.clone());
323        }
324        Ok(())
325    }
326
327    #[must_use]
328    pub fn finalize(mut self) -> DataValue {
329        if self.values.is_empty() {
330            return DataValue::Null;
331        }
332
333        self.values.sort_by(|a, b| {
335            use std::cmp::Ordering;
336            match (a, b) {
337                (DataValue::Integer(a), DataValue::Integer(b)) => a.cmp(b),
338                (DataValue::Float(a), DataValue::Float(b)) => {
339                    a.partial_cmp(b).unwrap_or(Ordering::Equal)
340                }
341                (DataValue::Integer(a), DataValue::Float(b)) => {
342                    (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal)
343                }
344                (DataValue::Float(a), DataValue::Integer(b)) => {
345                    a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal)
346                }
347                _ => Ordering::Equal,
348            }
349        });
350
351        let n = self.values.len();
352        if self.percentile == 0.0 {
353            return self.values[0].clone();
354        }
355        if self.percentile == 100.0 {
356            return self.values[n - 1].clone();
357        }
358
359        let pos = (self.percentile / 100.0) * ((n - 1) as f64);
361        let lower_idx = pos.floor() as usize;
362        let upper_idx = pos.ceil() as usize;
363
364        if lower_idx == upper_idx {
365            self.values[lower_idx].clone()
367        } else {
368            let fraction = pos - lower_idx as f64;
370            let lower_val = &self.values[lower_idx];
371            let upper_val = &self.values[upper_idx];
372
373            match (lower_val, upper_val) {
374                (DataValue::Integer(a), DataValue::Integer(b)) => {
375                    let result = *a as f64 + fraction * (*b - *a) as f64;
376                    if result.fract() == 0.0 {
377                        DataValue::Integer(result as i64)
378                    } else {
379                        DataValue::Float(result)
380                    }
381                }
382                (DataValue::Float(a), DataValue::Float(b)) => {
383                    DataValue::Float(a + fraction * (b - a))
384                }
385                (DataValue::Integer(a), DataValue::Float(b)) => {
386                    DataValue::Float(*a as f64 + fraction * (b - *a as f64))
387                }
388                (DataValue::Float(a), DataValue::Integer(b)) => {
389                    DataValue::Float(a + fraction * (*b as f64 - a))
390                }
391                _ => lower_val.clone(),
393            }
394        }
395    }
396}
397
398#[derive(Debug, Clone)]
400pub struct ModeState {
401    pub counts: std::collections::HashMap<String, (DataValue, i64)>,
402}
403
404impl Default for ModeState {
405    fn default() -> Self {
406        Self::new()
407    }
408}
409
410impl ModeState {
411    #[must_use]
412    pub fn new() -> Self {
413        Self {
414            counts: std::collections::HashMap::new(),
415        }
416    }
417
418    pub fn add(&mut self, value: &DataValue) -> Result<()> {
419        if matches!(value, DataValue::Null) {
420            return Ok(());
421        }
422
423        let key = match value {
425            DataValue::String(s) => s.clone(),
426            DataValue::InternedString(s) => s.to_string(),
427            DataValue::Integer(i) => i.to_string(),
428            DataValue::Float(f) => f.to_string(),
429            DataValue::Boolean(b) => b.to_string(),
430            DataValue::DateTime(dt) => dt.to_string(),
431            DataValue::Null => return Ok(()),
432        };
433
434        let entry = self.counts.entry(key).or_insert((value.clone(), 0));
436        entry.1 += 1;
437
438        Ok(())
439    }
440
441    #[must_use]
442    pub fn finalize(self) -> DataValue {
443        if self.counts.is_empty() {
444            return DataValue::Null;
445        }
446
447        let max_entry = self.counts.iter().max_by_key(|(_, (_, count))| count);
449
450        match max_entry {
451            Some((_, (value, _count))) => value.clone(),
452            None => DataValue::Null,
453        }
454    }
455}
456
/// Interface implemented by every aggregate stored in `AggregateRegistry`.
///
/// Implementations must be `Send + Sync`.
pub trait AggregateFunction: Send + Sync {
    /// Name under which the function is looked up.
    ///
    /// `AggregateRegistry::get` upper-cases the requested name before
    /// comparing it with this value, so an implementation has to return an
    /// upper-case name to be found.
    fn name(&self) -> &str;

    /// Creates the initial accumulator state for this function.
    fn init(&self) -> AggregateState;

    /// Folds one value into `state`.
    ///
    /// # Errors
    ///
    /// Implementation-defined; the state types in this file reject
    /// non-numeric input for numeric aggregates.
    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()>;

    /// Consumes `state` and produces the aggregate's result.
    fn finalize(&self, state: AggregateState) -> DataValue;

    /// Whether the function requires numeric input. Defaults to `false`.
    // NOTE(review): nothing in this file reads this flag; presumably callers
    // use it to validate argument types — confirm.
    fn requires_numeric(&self) -> bool {
        false
    }
}
476
477#[derive(Debug, Clone)]
479pub struct StringAggState {
480    pub values: Vec<String>,
481    pub separator: String,
482}
483
484impl Default for StringAggState {
485    fn default() -> Self {
486        Self::new(",")
487    }
488}
489
490impl StringAggState {
491    #[must_use]
492    pub fn new(separator: &str) -> Self {
493        Self {
494            values: Vec::new(),
495            separator: separator.to_string(),
496        }
497    }
498
499    pub fn add(&mut self, value: &DataValue) -> Result<()> {
500        match value {
501            DataValue::Null => Ok(()), DataValue::String(s) => {
503                self.values.push(s.clone());
504                Ok(())
505            }
506            DataValue::InternedString(s) => {
507                self.values.push(s.to_string());
508                Ok(())
509            }
510            DataValue::Integer(n) => {
511                self.values.push(n.to_string());
512                Ok(())
513            }
514            DataValue::Float(f) => {
515                self.values.push(f.to_string());
516                Ok(())
517            }
518            DataValue::Boolean(b) => {
519                self.values.push(b.to_string());
520                Ok(())
521            }
522            DataValue::DateTime(dt) => {
523                self.values.push(dt.to_string());
524                Ok(())
525            }
526        }
527    }
528
529    #[must_use]
530    pub fn finalize(self) -> DataValue {
531        if self.values.is_empty() {
532            DataValue::Null
533        } else {
534            DataValue::String(self.values.join(&self.separator))
535        }
536    }
537}
538
539pub struct AggregateRegistry {
541    functions: Vec<Box<dyn AggregateFunction>>,
542}
543
544impl AggregateRegistry {
545    #[must_use]
546    pub fn new() -> Self {
547        use analytics::{
548            CumMaxFunction, CumMinFunction, DeltasFunction, MavgFunction, PctChangeFunction,
549            RankFunction, SumsFunction,
550        };
551        use functions::{
552            AvgFunction, CountFunction, CountStarFunction, MaxFunction, MedianFunction,
553            MinFunction, ModeFunction, PercentileFunction, StdDevFunction, StdDevPopFunction,
554            StdDevSampFunction, StringAggFunction, VarPopFunction, VarSampFunction,
555            VarianceFunction,
556        };
557
558        let functions: Vec<Box<dyn AggregateFunction>> = vec![
559            Box::new(AvgFunction),
563            Box::new(MinFunction),
564            Box::new(MaxFunction),
565            Box::new(StdDevFunction),
566            Box::new(StdDevPopFunction),
567            Box::new(StdDevSampFunction),
568            Box::new(VarianceFunction),
569            Box::new(VarPopFunction),
570            Box::new(VarSampFunction),
571            Box::new(MedianFunction),
572            Box::new(ModeFunction),
573            Box::new(PercentileFunction),
574            Box::new(StringAggFunction),
575            Box::new(DeltasFunction),
577            Box::new(SumsFunction),
578            Box::new(MavgFunction),
579            Box::new(PctChangeFunction),
580            Box::new(RankFunction),
581            Box::new(CumMaxFunction),
582            Box::new(CumMinFunction),
583        ];
584
585        Self { functions }
586    }
587
588    #[must_use]
589    pub fn get(&self, name: &str) -> Option<&dyn AggregateFunction> {
590        let name_upper = name.to_uppercase();
591        self.functions
592            .iter()
593            .find(|f| f.name() == name_upper)
594            .map(std::convert::AsRef::as_ref)
595    }
596
597    #[must_use]
598    pub fn is_aggregate(&self, name: &str) -> bool {
599        use crate::sql::aggregate_functions::AggregateFunctionRegistry;
600
601        if self.get(name).is_some() {
603            return true;
604        }
605
606        let new_registry = AggregateFunctionRegistry::new();
608        new_registry.contains(name)
609    }
610}
611
612impl Default for AggregateRegistry {
613    fn default() -> Self {
614        Self::new()
615    }
616}
617
618pub fn contains_aggregate(expr: &crate::recursive_parser::SqlExpression) -> bool {
620    use crate::recursive_parser::SqlExpression;
621    use crate::sql::aggregate_functions::AggregateFunctionRegistry;
622
623    match expr {
624        SqlExpression::FunctionCall { name, args, .. } => {
625            let registry = AggregateRegistry::new();
627            if registry.is_aggregate(name) {
628                return true;
629            }
630            let new_registry = AggregateFunctionRegistry::new();
632            if new_registry.contains(name) {
633                return true;
634            }
635            args.iter().any(contains_aggregate)
637        }
638        SqlExpression::BinaryOp { left, right, .. } => {
639            contains_aggregate(left) || contains_aggregate(right)
640        }
641        SqlExpression::Not { expr } => contains_aggregate(expr),
642        SqlExpression::CaseExpression {
643            when_branches,
644            else_branch,
645        } => {
646            when_branches.iter().any(|branch| {
647                contains_aggregate(&branch.condition) || contains_aggregate(&branch.result)
648            }) || else_branch.as_ref().is_some_and(|e| contains_aggregate(e))
649        }
650        _ => false,
651    }
652}
653
654pub fn is_constant_expression(expr: &crate::recursive_parser::SqlExpression) -> bool {
657    use crate::recursive_parser::SqlExpression;
658
659    match expr {
660        SqlExpression::StringLiteral(_) => true,
661        SqlExpression::NumberLiteral(_) => true,
662        SqlExpression::BooleanLiteral(_) => true,
663        SqlExpression::Null => true,
664        SqlExpression::DateTimeConstructor { .. } => true,
665        SqlExpression::DateTimeToday { .. } => true,
666        SqlExpression::BinaryOp { left, right, .. } => {
668            is_constant_expression(left) && is_constant_expression(right)
669        }
670        SqlExpression::Not { expr } => is_constant_expression(expr),
672        SqlExpression::CaseExpression {
674            when_branches,
675            else_branch,
676        } => {
677            when_branches.iter().all(|branch| {
678                is_constant_expression(&branch.condition) && is_constant_expression(&branch.result)
679            }) && else_branch
680                .as_ref()
681                .map_or(true, |e| is_constant_expression(e))
682        }
683        SqlExpression::FunctionCall { args, .. } => {
686            !contains_aggregate(expr) && args.iter().all(is_constant_expression)
688        }
689        _ => false,
690    }
691}
692
693pub fn is_aggregate_compatible(expr: &crate::recursive_parser::SqlExpression) -> bool {
696    contains_aggregate(expr) || is_constant_expression(expr)
697}