1use anyhow::{anyhow, Result};
7
8use crate::data::datatable::DataValue;
9
10pub mod analytics;
11pub mod functions;
12
13#[derive(Debug, Clone)]
15pub enum AggregateState {
16 Count(i64),
17 Sum(SumState),
18 Avg(AvgState),
19 MinMax(MinMaxState),
20 Variance(VarianceState),
21 CollectList(Vec<DataValue>),
22 Percentile(PercentileState),
23 Mode(ModeState),
24 Analytics(analytics::AnalyticsState),
25 StringAgg(StringAggState),
26}
27
/// Accumulator for SUM that stays in integer arithmetic until a float is added.
///
/// The derived `Default` yields the same empty state as [`SumState::new`]:
/// both totals `None` and `has_values` set to `false`.
#[derive(Debug, Clone, Default)]
pub struct SumState {
    /// Integer running total; present while only integers/booleans were added.
    pub int_sum: Option<i64>,
    /// Floating-point running total; present once a float has been added.
    pub float_sum: Option<f64>,
    /// `true` after at least one non-null value has been accepted.
    pub has_values: bool,
}
41
42impl SumState {
43 #[must_use]
44 pub fn new() -> Self {
45 Self {
46 int_sum: None,
47 float_sum: None,
48 has_values: false,
49 }
50 }
51
52 pub fn add(&mut self, value: &DataValue) -> Result<()> {
53 match value {
54 DataValue::Null => Ok(()), DataValue::Integer(n) => {
56 self.has_values = true;
57 if let Some(ref mut sum) = self.int_sum {
58 *sum = sum.saturating_add(*n);
59 } else if let Some(ref mut fsum) = self.float_sum {
60 *fsum += *n as f64;
61 } else {
62 self.int_sum = Some(*n);
63 }
64 Ok(())
65 }
66 DataValue::Float(f) => {
67 self.has_values = true;
68 if let Some(isum) = self.int_sum.take() {
70 self.float_sum = Some(isum as f64 + f);
71 } else if let Some(ref mut fsum) = self.float_sum {
72 *fsum += f;
73 } else {
74 self.float_sum = Some(*f);
75 }
76 Ok(())
77 }
78 DataValue::Boolean(b) => {
79 let n = if *b { 1i64 } else { 0i64 };
82 self.has_values = true;
83 if let Some(ref mut sum) = self.int_sum {
84 *sum = sum.saturating_add(n);
85 } else if let Some(ref mut fsum) = self.float_sum {
86 *fsum += n as f64;
87 } else {
88 self.int_sum = Some(n);
89 }
90 Ok(())
91 }
92 _ => Err(anyhow!("Cannot sum non-numeric value")),
93 }
94 }
95
96 #[must_use]
97 pub fn finalize(self) -> DataValue {
98 if !self.has_values {
99 return DataValue::Null;
100 }
101
102 if let Some(fsum) = self.float_sum {
103 DataValue::Float(fsum)
104 } else if let Some(isum) = self.int_sum {
105 DataValue::Integer(isum)
106 } else {
107 DataValue::Null
108 }
109 }
110}
111
112#[derive(Debug, Clone)]
114pub struct AvgState {
115 pub sum: SumState,
116 pub count: i64,
117}
118
119impl Default for AvgState {
120 fn default() -> Self {
121 Self::new()
122 }
123}
124
125impl AvgState {
126 #[must_use]
127 pub fn new() -> Self {
128 Self {
129 sum: SumState::new(),
130 count: 0,
131 }
132 }
133
134 pub fn add(&mut self, value: &DataValue) -> Result<()> {
135 if !matches!(value, DataValue::Null) {
136 self.sum.add(value)?;
137 self.count += 1;
138 }
139 Ok(())
140 }
141
142 #[must_use]
143 pub fn finalize(self) -> DataValue {
144 if self.count == 0 {
145 return DataValue::Null;
146 }
147
148 let sum = self.sum.finalize();
149 match sum {
150 DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
151 DataValue::Float(f) => DataValue::Float(f / self.count as f64),
152 _ => DataValue::Null,
153 }
154 }
155}
156
157#[derive(Debug, Clone)]
159pub struct MinMaxState {
160 pub is_min: bool,
161 pub current: Option<DataValue>,
162}
163
164impl MinMaxState {
165 #[must_use]
166 pub fn new(is_min: bool) -> Self {
167 Self {
168 is_min,
169 current: None,
170 }
171 }
172
173 pub fn add(&mut self, value: &DataValue) -> Result<()> {
174 if matches!(value, DataValue::Null) {
175 return Ok(());
176 }
177
178 if let Some(ref current) = self.current {
179 let should_update = if self.is_min {
180 value < current
181 } else {
182 value > current
183 };
184
185 if should_update {
186 self.current = Some(value.clone());
187 }
188 } else {
189 self.current = Some(value.clone());
190 }
191
192 Ok(())
193 }
194
195 #[must_use]
196 pub fn finalize(self) -> DataValue {
197 self.current.unwrap_or(DataValue::Null)
198 }
199}
200
/// Accumulator for variance and standard deviation, kept as running sums.
///
/// The derived `Default` matches [`VarianceState::new`]: every field is zero.
#[derive(Debug, Clone, Default)]
pub struct VarianceState {
    /// Sum of all accepted values.
    pub sum: f64,
    /// Sum of the squares of all accepted values.
    pub sum_of_squares: f64,
    /// Number of accepted values.
    pub count: i64,
}
214
215impl VarianceState {
216 #[must_use]
217 pub fn new() -> Self {
218 Self {
219 sum: 0.0,
220 sum_of_squares: 0.0,
221 count: 0,
222 }
223 }
224
225 pub fn add(&mut self, value: &DataValue) -> Result<()> {
226 match value {
227 DataValue::Null => Ok(()), DataValue::Integer(n) => {
229 let f = *n as f64;
230 self.sum += f;
231 self.sum_of_squares += f * f;
232 self.count += 1;
233 Ok(())
234 }
235 DataValue::Float(f) => {
236 self.sum += f;
237 self.sum_of_squares += f * f;
238 self.count += 1;
239 Ok(())
240 }
241 _ => Err(anyhow!("Cannot compute variance of non-numeric value")),
242 }
243 }
244
245 #[must_use]
246 pub fn variance(&self) -> f64 {
247 if self.count <= 1 {
248 return 0.0;
249 }
250 let mean = self.sum / self.count as f64;
251 (self.sum_of_squares / self.count as f64) - (mean * mean)
252 }
253
254 #[must_use]
255 pub fn stddev(&self) -> f64 {
256 self.variance().sqrt()
257 }
258
259 #[must_use]
260 pub fn finalize_variance(self) -> DataValue {
261 if self.count == 0 {
262 DataValue::Null
263 } else {
264 DataValue::Float(self.variance())
265 }
266 }
267
268 #[must_use]
269 pub fn finalize_stddev(self) -> DataValue {
270 if self.count == 0 {
271 DataValue::Null
272 } else {
273 DataValue::Float(self.stddev())
274 }
275 }
276
277 #[must_use]
278 pub fn variance_sample(&self) -> f64 {
279 if self.count <= 1 {
280 return 0.0;
281 }
282 let mean = self.sum / self.count as f64;
283 let variance_n = (self.sum_of_squares / self.count as f64) - (mean * mean);
284 variance_n * (self.count as f64 / (self.count - 1) as f64)
286 }
287
288 #[must_use]
289 pub fn stddev_sample(&self) -> f64 {
290 self.variance_sample().sqrt()
291 }
292
293 #[must_use]
294 pub fn finalize_variance_sample(self) -> DataValue {
295 if self.count <= 1 {
296 DataValue::Null
297 } else {
298 DataValue::Float(self.variance_sample())
299 }
300 }
301
302 #[must_use]
303 pub fn finalize_stddev_sample(self) -> DataValue {
304 if self.count <= 1 {
305 DataValue::Null
306 } else {
307 DataValue::Float(self.stddev_sample())
308 }
309 }
310}
311
312#[derive(Debug, Clone)]
314pub struct PercentileState {
315 pub values: Vec<DataValue>,
316 pub percentile: f64,
317}
318
319impl Default for PercentileState {
320 fn default() -> Self {
321 Self::new(50.0) }
323}
324
325impl PercentileState {
326 #[must_use]
327 pub fn new(percentile: f64) -> Self {
328 Self {
329 values: Vec::new(),
330 percentile: percentile.clamp(0.0, 100.0),
331 }
332 }
333
334 pub fn add(&mut self, value: &DataValue) -> Result<()> {
335 if !matches!(value, DataValue::Null) {
336 self.values.push(value.clone());
337 }
338 Ok(())
339 }
340
341 #[must_use]
342 pub fn finalize(mut self) -> DataValue {
343 if self.values.is_empty() {
344 return DataValue::Null;
345 }
346
347 self.values.sort_by(|a, b| {
349 use std::cmp::Ordering;
350 match (a, b) {
351 (DataValue::Integer(a), DataValue::Integer(b)) => a.cmp(b),
352 (DataValue::Float(a), DataValue::Float(b)) => {
353 a.partial_cmp(b).unwrap_or(Ordering::Equal)
354 }
355 (DataValue::Integer(a), DataValue::Float(b)) => {
356 (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal)
357 }
358 (DataValue::Float(a), DataValue::Integer(b)) => {
359 a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal)
360 }
361 _ => Ordering::Equal,
362 }
363 });
364
365 let n = self.values.len();
366 if self.percentile == 0.0 {
367 return self.values[0].clone();
368 }
369 if self.percentile == 100.0 {
370 return self.values[n - 1].clone();
371 }
372
373 let pos = (self.percentile / 100.0) * ((n - 1) as f64);
375 let lower_idx = pos.floor() as usize;
376 let upper_idx = pos.ceil() as usize;
377
378 if lower_idx == upper_idx {
379 self.values[lower_idx].clone()
381 } else {
382 let fraction = pos - lower_idx as f64;
384 let lower_val = &self.values[lower_idx];
385 let upper_val = &self.values[upper_idx];
386
387 match (lower_val, upper_val) {
388 (DataValue::Integer(a), DataValue::Integer(b)) => {
389 let result = *a as f64 + fraction * (*b - *a) as f64;
390 if result.fract() == 0.0 {
391 DataValue::Integer(result as i64)
392 } else {
393 DataValue::Float(result)
394 }
395 }
396 (DataValue::Float(a), DataValue::Float(b)) => {
397 DataValue::Float(a + fraction * (b - a))
398 }
399 (DataValue::Integer(a), DataValue::Float(b)) => {
400 DataValue::Float(*a as f64 + fraction * (b - *a as f64))
401 }
402 (DataValue::Float(a), DataValue::Integer(b)) => {
403 DataValue::Float(a + fraction * (*b as f64 - a))
404 }
405 _ => lower_val.clone(),
407 }
408 }
409 }
410}
411
412#[derive(Debug, Clone)]
414pub struct ModeState {
415 pub counts: std::collections::HashMap<String, (DataValue, i64)>,
416}
417
418impl Default for ModeState {
419 fn default() -> Self {
420 Self::new()
421 }
422}
423
424impl ModeState {
425 #[must_use]
426 pub fn new() -> Self {
427 Self {
428 counts: std::collections::HashMap::new(),
429 }
430 }
431
432 pub fn add(&mut self, value: &DataValue) -> Result<()> {
433 if matches!(value, DataValue::Null) {
434 return Ok(());
435 }
436
437 let key = match value {
439 DataValue::String(s) => s.clone(),
440 DataValue::InternedString(s) => s.to_string(),
441 DataValue::Integer(i) => i.to_string(),
442 DataValue::Float(f) => f.to_string(),
443 DataValue::Boolean(b) => b.to_string(),
444 DataValue::DateTime(dt) => dt.to_string(),
445 DataValue::Vector(v) => {
446 let components: Vec<String> = v.iter().map(|f| f.to_string()).collect();
447 format!("[{}]", components.join(","))
448 }
449 DataValue::Null => return Ok(()),
450 };
451
452 let entry = self.counts.entry(key).or_insert((value.clone(), 0));
454 entry.1 += 1;
455
456 Ok(())
457 }
458
459 #[must_use]
460 pub fn finalize(self) -> DataValue {
461 if self.counts.is_empty() {
462 return DataValue::Null;
463 }
464
465 let max_entry = self.counts.iter().max_by_key(|(_, (_, count))| count);
467
468 match max_entry {
469 Some((_, (value, _count))) => value.clone(),
470 None => DataValue::Null,
471 }
472 }
473}
474
475pub trait AggregateFunction: Send + Sync {
477 fn name(&self) -> &str;
479
480 fn init(&self) -> AggregateState;
482
483 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()>;
485
486 fn finalize(&self, state: AggregateState) -> DataValue;
488
489 fn requires_numeric(&self) -> bool {
491 false
492 }
493}
494
495#[derive(Debug, Clone)]
497pub struct StringAggState {
498 pub values: Vec<String>,
499 pub separator: String,
500}
501
502impl Default for StringAggState {
503 fn default() -> Self {
504 Self::new(",")
505 }
506}
507
508impl StringAggState {
509 #[must_use]
510 pub fn new(separator: &str) -> Self {
511 Self {
512 values: Vec::new(),
513 separator: separator.to_string(),
514 }
515 }
516
517 pub fn add(&mut self, value: &DataValue) -> Result<()> {
518 match value {
519 DataValue::Null => Ok(()), DataValue::String(s) => {
521 self.values.push(s.clone());
522 Ok(())
523 }
524 DataValue::InternedString(s) => {
525 self.values.push(s.to_string());
526 Ok(())
527 }
528 DataValue::Integer(n) => {
529 self.values.push(n.to_string());
530 Ok(())
531 }
532 DataValue::Float(f) => {
533 self.values.push(f.to_string());
534 Ok(())
535 }
536 DataValue::Boolean(b) => {
537 self.values.push(b.to_string());
538 Ok(())
539 }
540 DataValue::DateTime(dt) => {
541 self.values.push(dt.to_string());
542 Ok(())
543 }
544 DataValue::Vector(v) => {
545 let components: Vec<String> = v.iter().map(|f| f.to_string()).collect();
546 self.values.push(format!("[{}]", components.join(",")));
547 Ok(())
548 }
549 }
550 }
551
552 #[must_use]
553 pub fn finalize(self) -> DataValue {
554 if self.values.is_empty() {
555 DataValue::Null
556 } else {
557 DataValue::String(self.values.join(&self.separator))
558 }
559 }
560}
561
562pub struct AggregateRegistry {
564 functions: Vec<Box<dyn AggregateFunction>>,
565}
566
567impl AggregateRegistry {
568 #[must_use]
569 pub fn new() -> Self {
570 use analytics::{
571 CumMaxFunction, CumMinFunction, DeltasFunction, MavgFunction, PctChangeFunction,
572 RankFunction, SumsFunction,
573 };
574 use functions::{
575 AvgFunction, MaxFunction, MedianFunction, MinFunction, ModeFunction,
576 PercentileFunction, StdDevFunction, StdDevPopFunction, StdDevSampFunction,
577 StringAggFunction, VarPopFunction, VarSampFunction, VarianceFunction,
578 };
579
580 let functions: Vec<Box<dyn AggregateFunction>> = vec![
581 Box::new(AvgFunction),
585 Box::new(MinFunction),
586 Box::new(MaxFunction),
587 Box::new(StdDevFunction),
588 Box::new(StdDevPopFunction),
589 Box::new(StdDevSampFunction),
590 Box::new(VarianceFunction),
591 Box::new(VarPopFunction),
592 Box::new(VarSampFunction),
593 Box::new(MedianFunction),
594 Box::new(ModeFunction),
595 Box::new(PercentileFunction),
596 Box::new(StringAggFunction),
597 Box::new(DeltasFunction),
599 Box::new(SumsFunction),
600 Box::new(MavgFunction),
601 Box::new(PctChangeFunction),
602 Box::new(RankFunction),
603 Box::new(CumMaxFunction),
604 Box::new(CumMinFunction),
605 ];
606
607 Self { functions }
608 }
609
610 #[must_use]
611 pub fn get(&self, name: &str) -> Option<&dyn AggregateFunction> {
612 let name_upper = name.to_uppercase();
613 self.functions
614 .iter()
615 .find(|f| f.name() == name_upper)
616 .map(std::convert::AsRef::as_ref)
617 }
618
619 #[must_use]
620 pub fn is_aggregate(&self, name: &str) -> bool {
621 use crate::sql::aggregate_functions::AggregateFunctionRegistry;
622
623 if self.get(name).is_some() {
625 return true;
626 }
627
628 let new_registry = AggregateFunctionRegistry::new();
630 new_registry.contains(name)
631 }
632}
633
634impl Default for AggregateRegistry {
635 fn default() -> Self {
636 Self::new()
637 }
638}
639
640pub fn contains_aggregate(expr: &crate::recursive_parser::SqlExpression) -> bool {
642 use crate::recursive_parser::SqlExpression;
643 use crate::sql::aggregate_functions::AggregateFunctionRegistry;
644
645 match expr {
646 SqlExpression::FunctionCall { name, args, .. } => {
647 let registry = AggregateRegistry::new();
649 if registry.is_aggregate(name) {
650 return true;
651 }
652 let new_registry = AggregateFunctionRegistry::new();
654 if new_registry.contains(name) {
655 return true;
656 }
657 args.iter().any(contains_aggregate)
659 }
660 SqlExpression::BinaryOp { left, right, .. } => {
661 contains_aggregate(left) || contains_aggregate(right)
662 }
663 SqlExpression::Not { expr } => contains_aggregate(expr),
664 SqlExpression::CaseExpression {
665 when_branches,
666 else_branch,
667 } => {
668 when_branches.iter().any(|branch| {
669 contains_aggregate(&branch.condition) || contains_aggregate(&branch.result)
670 }) || else_branch.as_ref().is_some_and(|e| contains_aggregate(e))
671 }
672 _ => false,
673 }
674}
675
676pub fn is_constant_expression(expr: &crate::recursive_parser::SqlExpression) -> bool {
679 use crate::recursive_parser::SqlExpression;
680
681 match expr {
682 SqlExpression::StringLiteral(_) => true,
683 SqlExpression::NumberLiteral(_) => true,
684 SqlExpression::BooleanLiteral(_) => true,
685 SqlExpression::Null => true,
686 SqlExpression::DateTimeConstructor { .. } => true,
687 SqlExpression::DateTimeToday { .. } => true,
688 SqlExpression::BinaryOp { left, right, .. } => {
690 is_constant_expression(left) && is_constant_expression(right)
691 }
692 SqlExpression::Not { expr } => is_constant_expression(expr),
694 SqlExpression::CaseExpression {
696 when_branches,
697 else_branch,
698 } => {
699 when_branches.iter().all(|branch| {
700 is_constant_expression(&branch.condition) && is_constant_expression(&branch.result)
701 }) && else_branch
702 .as_ref()
703 .map_or(true, |e| is_constant_expression(e))
704 }
705 SqlExpression::FunctionCall { args, .. } => {
708 !contains_aggregate(expr) && args.iter().all(is_constant_expression)
710 }
711 _ => false,
712 }
713}
714
715pub fn is_aggregate_compatible(expr: &crate::recursive_parser::SqlExpression) -> bool {
718 contains_aggregate(expr) || is_constant_expression(expr)
719}