1use anyhow::{anyhow, Result};
7
8use crate::data::datatable::DataValue;
9
10pub mod analytics;
11pub mod functions;
12
/// Accumulator state handed between [`AggregateFunction::init`],
/// [`AggregateFunction::accumulate`] and [`AggregateFunction::finalize`].
///
/// Each variant carries the working data for one kind of aggregate.
#[derive(Debug, Clone)]
pub enum AggregateState {
    /// A plain `i64` counter.
    Count(i64),
    /// Running total, see [`SumState`].
    Sum(SumState),
    /// Running total plus value count, see [`AvgState`].
    Avg(AvgState),
    /// Current minimum or maximum, see [`MinMaxState`].
    MinMax(MinMaxState),
    /// Running sums for variance / standard deviation, see [`VarianceState`].
    Variance(VarianceState),
    /// A list of collected values.
    CollectList(Vec<DataValue>),
    /// Buffered values plus the requested percentile, see [`PercentileState`].
    Percentile(PercentileState),
    /// Per-value occurrence counts, see [`ModeState`].
    Mode(ModeState),
    /// State defined by the `analytics` submodule.
    Analytics(analytics::AnalyticsState),
    /// Collected strings plus separator, see [`StringAggState`].
    StringAgg(StringAggState),
}
27
28#[derive(Debug, Clone)]
30pub struct SumState {
31 pub int_sum: Option<i64>,
32 pub float_sum: Option<f64>,
33 pub has_values: bool,
34}
35
36impl Default for SumState {
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42impl SumState {
43 #[must_use]
44 pub fn new() -> Self {
45 Self {
46 int_sum: None,
47 float_sum: None,
48 has_values: false,
49 }
50 }
51
52 pub fn add(&mut self, value: &DataValue) -> Result<()> {
53 match value {
54 DataValue::Null => Ok(()), DataValue::Integer(n) => {
56 self.has_values = true;
57 if let Some(ref mut sum) = self.int_sum {
58 *sum = sum.saturating_add(*n);
59 } else if let Some(ref mut fsum) = self.float_sum {
60 *fsum += *n as f64;
61 } else {
62 self.int_sum = Some(*n);
63 }
64 Ok(())
65 }
66 DataValue::Float(f) => {
67 self.has_values = true;
68 if let Some(isum) = self.int_sum.take() {
70 self.float_sum = Some(isum as f64 + f);
71 } else if let Some(ref mut fsum) = self.float_sum {
72 *fsum += f;
73 } else {
74 self.float_sum = Some(*f);
75 }
76 Ok(())
77 }
78 _ => Err(anyhow!("Cannot sum non-numeric value")),
79 }
80 }
81
82 #[must_use]
83 pub fn finalize(self) -> DataValue {
84 if !self.has_values {
85 return DataValue::Null;
86 }
87
88 if let Some(fsum) = self.float_sum {
89 DataValue::Float(fsum)
90 } else if let Some(isum) = self.int_sum {
91 DataValue::Integer(isum)
92 } else {
93 DataValue::Null
94 }
95 }
96}
97
98#[derive(Debug, Clone)]
100pub struct AvgState {
101 pub sum: SumState,
102 pub count: i64,
103}
104
105impl Default for AvgState {
106 fn default() -> Self {
107 Self::new()
108 }
109}
110
111impl AvgState {
112 #[must_use]
113 pub fn new() -> Self {
114 Self {
115 sum: SumState::new(),
116 count: 0,
117 }
118 }
119
120 pub fn add(&mut self, value: &DataValue) -> Result<()> {
121 if !matches!(value, DataValue::Null) {
122 self.sum.add(value)?;
123 self.count += 1;
124 }
125 Ok(())
126 }
127
128 #[must_use]
129 pub fn finalize(self) -> DataValue {
130 if self.count == 0 {
131 return DataValue::Null;
132 }
133
134 let sum = self.sum.finalize();
135 match sum {
136 DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
137 DataValue::Float(f) => DataValue::Float(f / self.count as f64),
138 _ => DataValue::Null,
139 }
140 }
141}
142
143#[derive(Debug, Clone)]
145pub struct MinMaxState {
146 pub is_min: bool,
147 pub current: Option<DataValue>,
148}
149
150impl MinMaxState {
151 #[must_use]
152 pub fn new(is_min: bool) -> Self {
153 Self {
154 is_min,
155 current: None,
156 }
157 }
158
159 pub fn add(&mut self, value: &DataValue) -> Result<()> {
160 if matches!(value, DataValue::Null) {
161 return Ok(());
162 }
163
164 if let Some(ref current) = self.current {
165 let should_update = if self.is_min {
166 value < current
167 } else {
168 value > current
169 };
170
171 if should_update {
172 self.current = Some(value.clone());
173 }
174 } else {
175 self.current = Some(value.clone());
176 }
177
178 Ok(())
179 }
180
181 #[must_use]
182 pub fn finalize(self) -> DataValue {
183 self.current.unwrap_or(DataValue::Null)
184 }
185}
186
187#[derive(Debug, Clone)]
189pub struct VarianceState {
190 pub sum: f64,
191 pub sum_of_squares: f64,
192 pub count: i64,
193}
194
195impl Default for VarianceState {
196 fn default() -> Self {
197 Self::new()
198 }
199}
200
201impl VarianceState {
202 #[must_use]
203 pub fn new() -> Self {
204 Self {
205 sum: 0.0,
206 sum_of_squares: 0.0,
207 count: 0,
208 }
209 }
210
211 pub fn add(&mut self, value: &DataValue) -> Result<()> {
212 match value {
213 DataValue::Null => Ok(()), DataValue::Integer(n) => {
215 let f = *n as f64;
216 self.sum += f;
217 self.sum_of_squares += f * f;
218 self.count += 1;
219 Ok(())
220 }
221 DataValue::Float(f) => {
222 self.sum += f;
223 self.sum_of_squares += f * f;
224 self.count += 1;
225 Ok(())
226 }
227 _ => Err(anyhow!("Cannot compute variance of non-numeric value")),
228 }
229 }
230
231 #[must_use]
232 pub fn variance(&self) -> f64 {
233 if self.count <= 1 {
234 return 0.0;
235 }
236 let mean = self.sum / self.count as f64;
237 (self.sum_of_squares / self.count as f64) - (mean * mean)
238 }
239
240 #[must_use]
241 pub fn stddev(&self) -> f64 {
242 self.variance().sqrt()
243 }
244
245 #[must_use]
246 pub fn finalize_variance(self) -> DataValue {
247 if self.count == 0 {
248 DataValue::Null
249 } else {
250 DataValue::Float(self.variance())
251 }
252 }
253
254 #[must_use]
255 pub fn finalize_stddev(self) -> DataValue {
256 if self.count == 0 {
257 DataValue::Null
258 } else {
259 DataValue::Float(self.stddev())
260 }
261 }
262
263 #[must_use]
264 pub fn variance_sample(&self) -> f64 {
265 if self.count <= 1 {
266 return 0.0;
267 }
268 let mean = self.sum / self.count as f64;
269 let variance_n = (self.sum_of_squares / self.count as f64) - (mean * mean);
270 variance_n * (self.count as f64 / (self.count - 1) as f64)
272 }
273
274 #[must_use]
275 pub fn stddev_sample(&self) -> f64 {
276 self.variance_sample().sqrt()
277 }
278
279 #[must_use]
280 pub fn finalize_variance_sample(self) -> DataValue {
281 if self.count <= 1 {
282 DataValue::Null
283 } else {
284 DataValue::Float(self.variance_sample())
285 }
286 }
287
288 #[must_use]
289 pub fn finalize_stddev_sample(self) -> DataValue {
290 if self.count <= 1 {
291 DataValue::Null
292 } else {
293 DataValue::Float(self.stddev_sample())
294 }
295 }
296}
297
298#[derive(Debug, Clone)]
300pub struct PercentileState {
301 pub values: Vec<DataValue>,
302 pub percentile: f64,
303}
304
305impl Default for PercentileState {
306 fn default() -> Self {
307 Self::new(50.0) }
309}
310
311impl PercentileState {
312 #[must_use]
313 pub fn new(percentile: f64) -> Self {
314 Self {
315 values: Vec::new(),
316 percentile: percentile.clamp(0.0, 100.0),
317 }
318 }
319
320 pub fn add(&mut self, value: &DataValue) -> Result<()> {
321 if !matches!(value, DataValue::Null) {
322 self.values.push(value.clone());
323 }
324 Ok(())
325 }
326
327 #[must_use]
328 pub fn finalize(mut self) -> DataValue {
329 if self.values.is_empty() {
330 return DataValue::Null;
331 }
332
333 self.values.sort_by(|a, b| {
335 use std::cmp::Ordering;
336 match (a, b) {
337 (DataValue::Integer(a), DataValue::Integer(b)) => a.cmp(b),
338 (DataValue::Float(a), DataValue::Float(b)) => {
339 a.partial_cmp(b).unwrap_or(Ordering::Equal)
340 }
341 (DataValue::Integer(a), DataValue::Float(b)) => {
342 (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal)
343 }
344 (DataValue::Float(a), DataValue::Integer(b)) => {
345 a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal)
346 }
347 _ => Ordering::Equal,
348 }
349 });
350
351 let n = self.values.len();
352 if self.percentile == 0.0 {
353 return self.values[0].clone();
354 }
355 if self.percentile == 100.0 {
356 return self.values[n - 1].clone();
357 }
358
359 let pos = (self.percentile / 100.0) * ((n - 1) as f64);
361 let lower_idx = pos.floor() as usize;
362 let upper_idx = pos.ceil() as usize;
363
364 if lower_idx == upper_idx {
365 self.values[lower_idx].clone()
367 } else {
368 let fraction = pos - lower_idx as f64;
370 let lower_val = &self.values[lower_idx];
371 let upper_val = &self.values[upper_idx];
372
373 match (lower_val, upper_val) {
374 (DataValue::Integer(a), DataValue::Integer(b)) => {
375 let result = *a as f64 + fraction * (*b - *a) as f64;
376 if result.fract() == 0.0 {
377 DataValue::Integer(result as i64)
378 } else {
379 DataValue::Float(result)
380 }
381 }
382 (DataValue::Float(a), DataValue::Float(b)) => {
383 DataValue::Float(a + fraction * (b - a))
384 }
385 (DataValue::Integer(a), DataValue::Float(b)) => {
386 DataValue::Float(*a as f64 + fraction * (b - *a as f64))
387 }
388 (DataValue::Float(a), DataValue::Integer(b)) => {
389 DataValue::Float(a + fraction * (*b as f64 - a))
390 }
391 _ => lower_val.clone(),
393 }
394 }
395 }
396}
397
398#[derive(Debug, Clone)]
400pub struct ModeState {
401 pub counts: std::collections::HashMap<String, (DataValue, i64)>,
402}
403
404impl Default for ModeState {
405 fn default() -> Self {
406 Self::new()
407 }
408}
409
410impl ModeState {
411 #[must_use]
412 pub fn new() -> Self {
413 Self {
414 counts: std::collections::HashMap::new(),
415 }
416 }
417
418 pub fn add(&mut self, value: &DataValue) -> Result<()> {
419 if matches!(value, DataValue::Null) {
420 return Ok(());
421 }
422
423 let key = match value {
425 DataValue::String(s) => s.clone(),
426 DataValue::InternedString(s) => s.to_string(),
427 DataValue::Integer(i) => i.to_string(),
428 DataValue::Float(f) => f.to_string(),
429 DataValue::Boolean(b) => b.to_string(),
430 DataValue::DateTime(dt) => dt.to_string(),
431 DataValue::Null => return Ok(()),
432 };
433
434 let entry = self.counts.entry(key).or_insert((value.clone(), 0));
436 entry.1 += 1;
437
438 Ok(())
439 }
440
441 #[must_use]
442 pub fn finalize(self) -> DataValue {
443 if self.counts.is_empty() {
444 return DataValue::Null;
445 }
446
447 let max_entry = self.counts.iter().max_by_key(|(_, (_, count))| count);
449
450 match max_entry {
451 Some((_, (value, _count))) => value.clone(),
452 None => DataValue::Null,
453 }
454 }
455}
456
/// Interface implemented by every aggregate function held in
/// [`AggregateRegistry`].
///
/// An aggregate is driven in three steps: [`init`](Self::init) produces a
/// fresh state, [`accumulate`](Self::accumulate) is fed values together with
/// that state, and [`finalize`](Self::finalize) turns the state into a result.
pub trait AggregateFunction: Send + Sync {
    /// Name of the function.
    ///
    /// [`AggregateRegistry::get`] compares this against the upper-cased
    /// lookup name, so a function is only found there if its name is
    /// upper-case.
    fn name(&self) -> &str;

    /// Returns the initial state for a new aggregation.
    fn init(&self) -> AggregateState;

    /// Folds `value` into `state`.
    ///
    /// # Errors
    ///
    /// Implementations report failure through the returned `Result`; the
    /// conditions are implementation-specific.
    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()>;

    /// Consumes `state` and produces the aggregate's result.
    fn finalize(&self, state: AggregateState) -> DataValue;

    /// Defaults to `false`. Presumably signals that the function only
    /// accepts numeric input — confirm against the callers.
    fn requires_numeric(&self) -> bool {
        false
    }
}
476
477#[derive(Debug, Clone)]
479pub struct StringAggState {
480 pub values: Vec<String>,
481 pub separator: String,
482}
483
484impl Default for StringAggState {
485 fn default() -> Self {
486 Self::new(",")
487 }
488}
489
490impl StringAggState {
491 #[must_use]
492 pub fn new(separator: &str) -> Self {
493 Self {
494 values: Vec::new(),
495 separator: separator.to_string(),
496 }
497 }
498
499 pub fn add(&mut self, value: &DataValue) -> Result<()> {
500 match value {
501 DataValue::Null => Ok(()), DataValue::String(s) => {
503 self.values.push(s.clone());
504 Ok(())
505 }
506 DataValue::InternedString(s) => {
507 self.values.push(s.to_string());
508 Ok(())
509 }
510 DataValue::Integer(n) => {
511 self.values.push(n.to_string());
512 Ok(())
513 }
514 DataValue::Float(f) => {
515 self.values.push(f.to_string());
516 Ok(())
517 }
518 DataValue::Boolean(b) => {
519 self.values.push(b.to_string());
520 Ok(())
521 }
522 DataValue::DateTime(dt) => {
523 self.values.push(dt.to_string());
524 Ok(())
525 }
526 }
527 }
528
529 #[must_use]
530 pub fn finalize(self) -> DataValue {
531 if self.values.is_empty() {
532 DataValue::Null
533 } else {
534 DataValue::String(self.values.join(&self.separator))
535 }
536 }
537}
538
539pub struct AggregateRegistry {
541 functions: Vec<Box<dyn AggregateFunction>>,
542}
543
544impl AggregateRegistry {
545 #[must_use]
546 pub fn new() -> Self {
547 use analytics::{
548 CumMaxFunction, CumMinFunction, DeltasFunction, MavgFunction, PctChangeFunction,
549 RankFunction, SumsFunction,
550 };
551 use functions::{
552 AvgFunction, MaxFunction, MedianFunction, MinFunction, ModeFunction,
553 PercentileFunction, StdDevFunction, StdDevPopFunction, StdDevSampFunction,
554 StringAggFunction, VarPopFunction, VarSampFunction, VarianceFunction,
555 };
556
557 let functions: Vec<Box<dyn AggregateFunction>> = vec![
558 Box::new(AvgFunction),
562 Box::new(MinFunction),
563 Box::new(MaxFunction),
564 Box::new(StdDevFunction),
565 Box::new(StdDevPopFunction),
566 Box::new(StdDevSampFunction),
567 Box::new(VarianceFunction),
568 Box::new(VarPopFunction),
569 Box::new(VarSampFunction),
570 Box::new(MedianFunction),
571 Box::new(ModeFunction),
572 Box::new(PercentileFunction),
573 Box::new(StringAggFunction),
574 Box::new(DeltasFunction),
576 Box::new(SumsFunction),
577 Box::new(MavgFunction),
578 Box::new(PctChangeFunction),
579 Box::new(RankFunction),
580 Box::new(CumMaxFunction),
581 Box::new(CumMinFunction),
582 ];
583
584 Self { functions }
585 }
586
587 #[must_use]
588 pub fn get(&self, name: &str) -> Option<&dyn AggregateFunction> {
589 let name_upper = name.to_uppercase();
590 self.functions
591 .iter()
592 .find(|f| f.name() == name_upper)
593 .map(std::convert::AsRef::as_ref)
594 }
595
596 #[must_use]
597 pub fn is_aggregate(&self, name: &str) -> bool {
598 use crate::sql::aggregate_functions::AggregateFunctionRegistry;
599
600 if self.get(name).is_some() {
602 return true;
603 }
604
605 let new_registry = AggregateFunctionRegistry::new();
607 new_registry.contains(name)
608 }
609}
610
611impl Default for AggregateRegistry {
612 fn default() -> Self {
613 Self::new()
614 }
615}
616
617pub fn contains_aggregate(expr: &crate::recursive_parser::SqlExpression) -> bool {
619 use crate::recursive_parser::SqlExpression;
620 use crate::sql::aggregate_functions::AggregateFunctionRegistry;
621
622 match expr {
623 SqlExpression::FunctionCall { name, args, .. } => {
624 let registry = AggregateRegistry::new();
626 if registry.is_aggregate(name) {
627 return true;
628 }
629 let new_registry = AggregateFunctionRegistry::new();
631 if new_registry.contains(name) {
632 return true;
633 }
634 args.iter().any(contains_aggregate)
636 }
637 SqlExpression::BinaryOp { left, right, .. } => {
638 contains_aggregate(left) || contains_aggregate(right)
639 }
640 SqlExpression::Not { expr } => contains_aggregate(expr),
641 SqlExpression::CaseExpression {
642 when_branches,
643 else_branch,
644 } => {
645 when_branches.iter().any(|branch| {
646 contains_aggregate(&branch.condition) || contains_aggregate(&branch.result)
647 }) || else_branch.as_ref().is_some_and(|e| contains_aggregate(e))
648 }
649 _ => false,
650 }
651}
652
653pub fn is_constant_expression(expr: &crate::recursive_parser::SqlExpression) -> bool {
656 use crate::recursive_parser::SqlExpression;
657
658 match expr {
659 SqlExpression::StringLiteral(_) => true,
660 SqlExpression::NumberLiteral(_) => true,
661 SqlExpression::BooleanLiteral(_) => true,
662 SqlExpression::Null => true,
663 SqlExpression::DateTimeConstructor { .. } => true,
664 SqlExpression::DateTimeToday { .. } => true,
665 SqlExpression::BinaryOp { left, right, .. } => {
667 is_constant_expression(left) && is_constant_expression(right)
668 }
669 SqlExpression::Not { expr } => is_constant_expression(expr),
671 SqlExpression::CaseExpression {
673 when_branches,
674 else_branch,
675 } => {
676 when_branches.iter().all(|branch| {
677 is_constant_expression(&branch.condition) && is_constant_expression(&branch.result)
678 }) && else_branch
679 .as_ref()
680 .map_or(true, |e| is_constant_expression(e))
681 }
682 SqlExpression::FunctionCall { args, .. } => {
685 !contains_aggregate(expr) && args.iter().all(is_constant_expression)
687 }
688 _ => false,
689 }
690}
691
692pub fn is_aggregate_compatible(expr: &crate::recursive_parser::SqlExpression) -> bool {
695 contains_aggregate(expr) || is_constant_expression(expr)
696}