1use anyhow::{anyhow, Result};
7
8use crate::data::datatable::DataValue;
9
10pub mod analytics;
11pub mod functions;
12
/// Accumulator state for one aggregate computation.
///
/// Each variant wraps the state type used by one family of aggregates. Apart
/// from `analytics::AnalyticsState`, the wrapped types are declared in this
/// module.
#[derive(Debug, Clone)]
pub enum AggregateState {
    /// A plain `i64` counter.
    Count(i64),
    /// Running total; see [`SumState`].
    Sum(SumState),
    /// Running total plus input count; see [`AvgState`].
    Avg(AvgState),
    /// Current minimum or maximum; see [`MinMaxState`].
    MinMax(MinMaxState),
    /// Sums needed for variance and standard deviation; see [`VarianceState`].
    Variance(VarianceState),
    /// Values gathered into a list.
    CollectList(Vec<DataValue>),
    /// Buffered values plus the requested percentile; see [`PercentileState`].
    Percentile(PercentileState),
    /// Per-value occurrence counts; see [`ModeState`].
    Mode(ModeState),
    /// State owned by the `analytics` submodule.
    Analytics(analytics::AnalyticsState),
    /// Collected strings plus separator; see [`StringAggState`].
    StringAgg(StringAggState),
}
27
28#[derive(Debug, Clone)]
30pub struct SumState {
31 pub int_sum: Option<i64>,
32 pub float_sum: Option<f64>,
33 pub has_values: bool,
34}
35
36impl Default for SumState {
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42impl SumState {
43 #[must_use]
44 pub fn new() -> Self {
45 Self {
46 int_sum: None,
47 float_sum: None,
48 has_values: false,
49 }
50 }
51
52 pub fn add(&mut self, value: &DataValue) -> Result<()> {
53 match value {
54 DataValue::Null => Ok(()), DataValue::Integer(n) => {
56 self.has_values = true;
57 if let Some(ref mut sum) = self.int_sum {
58 *sum = sum.saturating_add(*n);
59 } else if let Some(ref mut fsum) = self.float_sum {
60 *fsum += *n as f64;
61 } else {
62 self.int_sum = Some(*n);
63 }
64 Ok(())
65 }
66 DataValue::Float(f) => {
67 self.has_values = true;
68 if let Some(isum) = self.int_sum.take() {
70 self.float_sum = Some(isum as f64 + f);
71 } else if let Some(ref mut fsum) = self.float_sum {
72 *fsum += f;
73 } else {
74 self.float_sum = Some(*f);
75 }
76 Ok(())
77 }
78 _ => Err(anyhow!("Cannot sum non-numeric value")),
79 }
80 }
81
82 #[must_use]
83 pub fn finalize(self) -> DataValue {
84 if !self.has_values {
85 return DataValue::Null;
86 }
87
88 if let Some(fsum) = self.float_sum {
89 DataValue::Float(fsum)
90 } else if let Some(isum) = self.int_sum {
91 DataValue::Integer(isum)
92 } else {
93 DataValue::Null
94 }
95 }
96}
97
98#[derive(Debug, Clone)]
100pub struct AvgState {
101 pub sum: SumState,
102 pub count: i64,
103}
104
105impl Default for AvgState {
106 fn default() -> Self {
107 Self::new()
108 }
109}
110
111impl AvgState {
112 #[must_use]
113 pub fn new() -> Self {
114 Self {
115 sum: SumState::new(),
116 count: 0,
117 }
118 }
119
120 pub fn add(&mut self, value: &DataValue) -> Result<()> {
121 if !matches!(value, DataValue::Null) {
122 self.sum.add(value)?;
123 self.count += 1;
124 }
125 Ok(())
126 }
127
128 #[must_use]
129 pub fn finalize(self) -> DataValue {
130 if self.count == 0 {
131 return DataValue::Null;
132 }
133
134 let sum = self.sum.finalize();
135 match sum {
136 DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
137 DataValue::Float(f) => DataValue::Float(f / self.count as f64),
138 _ => DataValue::Null,
139 }
140 }
141}
142
143#[derive(Debug, Clone)]
145pub struct MinMaxState {
146 pub is_min: bool,
147 pub current: Option<DataValue>,
148}
149
150impl MinMaxState {
151 #[must_use]
152 pub fn new(is_min: bool) -> Self {
153 Self {
154 is_min,
155 current: None,
156 }
157 }
158
159 pub fn add(&mut self, value: &DataValue) -> Result<()> {
160 if matches!(value, DataValue::Null) {
161 return Ok(());
162 }
163
164 if let Some(ref current) = self.current {
165 let should_update = if self.is_min {
166 value < current
167 } else {
168 value > current
169 };
170
171 if should_update {
172 self.current = Some(value.clone());
173 }
174 } else {
175 self.current = Some(value.clone());
176 }
177
178 Ok(())
179 }
180
181 #[must_use]
182 pub fn finalize(self) -> DataValue {
183 self.current.unwrap_or(DataValue::Null)
184 }
185}
186
187#[derive(Debug, Clone)]
189pub struct VarianceState {
190 pub sum: f64,
191 pub sum_of_squares: f64,
192 pub count: i64,
193}
194
195impl Default for VarianceState {
196 fn default() -> Self {
197 Self::new()
198 }
199}
200
201impl VarianceState {
202 #[must_use]
203 pub fn new() -> Self {
204 Self {
205 sum: 0.0,
206 sum_of_squares: 0.0,
207 count: 0,
208 }
209 }
210
211 pub fn add(&mut self, value: &DataValue) -> Result<()> {
212 match value {
213 DataValue::Null => Ok(()), DataValue::Integer(n) => {
215 let f = *n as f64;
216 self.sum += f;
217 self.sum_of_squares += f * f;
218 self.count += 1;
219 Ok(())
220 }
221 DataValue::Float(f) => {
222 self.sum += f;
223 self.sum_of_squares += f * f;
224 self.count += 1;
225 Ok(())
226 }
227 _ => Err(anyhow!("Cannot compute variance of non-numeric value")),
228 }
229 }
230
231 #[must_use]
232 pub fn variance(&self) -> f64 {
233 if self.count <= 1 {
234 return 0.0;
235 }
236 let mean = self.sum / self.count as f64;
237 (self.sum_of_squares / self.count as f64) - (mean * mean)
238 }
239
240 #[must_use]
241 pub fn stddev(&self) -> f64 {
242 self.variance().sqrt()
243 }
244
245 #[must_use]
246 pub fn finalize_variance(self) -> DataValue {
247 if self.count == 0 {
248 DataValue::Null
249 } else {
250 DataValue::Float(self.variance())
251 }
252 }
253
254 #[must_use]
255 pub fn finalize_stddev(self) -> DataValue {
256 if self.count == 0 {
257 DataValue::Null
258 } else {
259 DataValue::Float(self.stddev())
260 }
261 }
262
263 #[must_use]
264 pub fn variance_sample(&self) -> f64 {
265 if self.count <= 1 {
266 return 0.0;
267 }
268 let mean = self.sum / self.count as f64;
269 let variance_n = (self.sum_of_squares / self.count as f64) - (mean * mean);
270 variance_n * (self.count as f64 / (self.count - 1) as f64)
272 }
273
274 #[must_use]
275 pub fn stddev_sample(&self) -> f64 {
276 self.variance_sample().sqrt()
277 }
278
279 #[must_use]
280 pub fn finalize_variance_sample(self) -> DataValue {
281 if self.count <= 1 {
282 DataValue::Null
283 } else {
284 DataValue::Float(self.variance_sample())
285 }
286 }
287
288 #[must_use]
289 pub fn finalize_stddev_sample(self) -> DataValue {
290 if self.count <= 1 {
291 DataValue::Null
292 } else {
293 DataValue::Float(self.stddev_sample())
294 }
295 }
296}
297
298#[derive(Debug, Clone)]
300pub struct PercentileState {
301 pub values: Vec<DataValue>,
302 pub percentile: f64,
303}
304
305impl Default for PercentileState {
306 fn default() -> Self {
307 Self::new(50.0) }
309}
310
311impl PercentileState {
312 #[must_use]
313 pub fn new(percentile: f64) -> Self {
314 Self {
315 values: Vec::new(),
316 percentile: percentile.clamp(0.0, 100.0),
317 }
318 }
319
320 pub fn add(&mut self, value: &DataValue) -> Result<()> {
321 if !matches!(value, DataValue::Null) {
322 self.values.push(value.clone());
323 }
324 Ok(())
325 }
326
327 #[must_use]
328 pub fn finalize(mut self) -> DataValue {
329 if self.values.is_empty() {
330 return DataValue::Null;
331 }
332
333 self.values.sort_by(|a, b| {
335 use std::cmp::Ordering;
336 match (a, b) {
337 (DataValue::Integer(a), DataValue::Integer(b)) => a.cmp(b),
338 (DataValue::Float(a), DataValue::Float(b)) => {
339 a.partial_cmp(b).unwrap_or(Ordering::Equal)
340 }
341 (DataValue::Integer(a), DataValue::Float(b)) => {
342 (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal)
343 }
344 (DataValue::Float(a), DataValue::Integer(b)) => {
345 a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal)
346 }
347 _ => Ordering::Equal,
348 }
349 });
350
351 let n = self.values.len();
352 if self.percentile == 0.0 {
353 return self.values[0].clone();
354 }
355 if self.percentile == 100.0 {
356 return self.values[n - 1].clone();
357 }
358
359 let pos = (self.percentile / 100.0) * ((n - 1) as f64);
361 let lower_idx = pos.floor() as usize;
362 let upper_idx = pos.ceil() as usize;
363
364 if lower_idx == upper_idx {
365 self.values[lower_idx].clone()
367 } else {
368 let fraction = pos - lower_idx as f64;
370 let lower_val = &self.values[lower_idx];
371 let upper_val = &self.values[upper_idx];
372
373 match (lower_val, upper_val) {
374 (DataValue::Integer(a), DataValue::Integer(b)) => {
375 let result = *a as f64 + fraction * (*b - *a) as f64;
376 if result.fract() == 0.0 {
377 DataValue::Integer(result as i64)
378 } else {
379 DataValue::Float(result)
380 }
381 }
382 (DataValue::Float(a), DataValue::Float(b)) => {
383 DataValue::Float(a + fraction * (b - a))
384 }
385 (DataValue::Integer(a), DataValue::Float(b)) => {
386 DataValue::Float(*a as f64 + fraction * (b - *a as f64))
387 }
388 (DataValue::Float(a), DataValue::Integer(b)) => {
389 DataValue::Float(a + fraction * (*b as f64 - a))
390 }
391 _ => lower_val.clone(),
393 }
394 }
395 }
396}
397
398#[derive(Debug, Clone)]
400pub struct ModeState {
401 pub counts: std::collections::HashMap<String, (DataValue, i64)>,
402}
403
404impl Default for ModeState {
405 fn default() -> Self {
406 Self::new()
407 }
408}
409
410impl ModeState {
411 #[must_use]
412 pub fn new() -> Self {
413 Self {
414 counts: std::collections::HashMap::new(),
415 }
416 }
417
418 pub fn add(&mut self, value: &DataValue) -> Result<()> {
419 if matches!(value, DataValue::Null) {
420 return Ok(());
421 }
422
423 let key = match value {
425 DataValue::String(s) => s.clone(),
426 DataValue::InternedString(s) => s.to_string(),
427 DataValue::Integer(i) => i.to_string(),
428 DataValue::Float(f) => f.to_string(),
429 DataValue::Boolean(b) => b.to_string(),
430 DataValue::DateTime(dt) => dt.to_string(),
431 DataValue::Vector(v) => {
432 let components: Vec<String> = v.iter().map(|f| f.to_string()).collect();
433 format!("[{}]", components.join(","))
434 }
435 DataValue::Null => return Ok(()),
436 };
437
438 let entry = self.counts.entry(key).or_insert((value.clone(), 0));
440 entry.1 += 1;
441
442 Ok(())
443 }
444
445 #[must_use]
446 pub fn finalize(self) -> DataValue {
447 if self.counts.is_empty() {
448 return DataValue::Null;
449 }
450
451 let max_entry = self.counts.iter().max_by_key(|(_, (_, count))| count);
453
454 match max_entry {
455 Some((_, (value, _count))) => value.clone(),
456 None => DataValue::Null,
457 }
458 }
459}
460
/// Interface implemented by every aggregate stored in [`AggregateRegistry`].
///
/// An implementation supplies a name, creates its own [`AggregateState`],
/// folds values into that state and finally turns it into a result.
pub trait AggregateFunction: Send + Sync {
    /// Name under which the function is looked up.
    ///
    /// `AggregateRegistry::get` compares this against the upper-cased lookup
    /// name, so a name that is not already upper case will not be found.
    fn name(&self) -> &str;

    /// Creates the initial state for a new aggregation.
    fn init(&self) -> AggregateState;

    /// Folds one value into `state`.
    ///
    /// # Errors
    ///
    /// Implementations may reject the value; the conditions are defined by
    /// each implementation.
    fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()>;

    /// Consumes `state` and produces the aggregate's result.
    fn finalize(&self, state: AggregateState) -> DataValue;

    /// Whether the function only accepts numeric input; `false` unless
    /// overridden.
    fn requires_numeric(&self) -> bool {
        false
    }
}
480
481#[derive(Debug, Clone)]
483pub struct StringAggState {
484 pub values: Vec<String>,
485 pub separator: String,
486}
487
488impl Default for StringAggState {
489 fn default() -> Self {
490 Self::new(",")
491 }
492}
493
494impl StringAggState {
495 #[must_use]
496 pub fn new(separator: &str) -> Self {
497 Self {
498 values: Vec::new(),
499 separator: separator.to_string(),
500 }
501 }
502
503 pub fn add(&mut self, value: &DataValue) -> Result<()> {
504 match value {
505 DataValue::Null => Ok(()), DataValue::String(s) => {
507 self.values.push(s.clone());
508 Ok(())
509 }
510 DataValue::InternedString(s) => {
511 self.values.push(s.to_string());
512 Ok(())
513 }
514 DataValue::Integer(n) => {
515 self.values.push(n.to_string());
516 Ok(())
517 }
518 DataValue::Float(f) => {
519 self.values.push(f.to_string());
520 Ok(())
521 }
522 DataValue::Boolean(b) => {
523 self.values.push(b.to_string());
524 Ok(())
525 }
526 DataValue::DateTime(dt) => {
527 self.values.push(dt.to_string());
528 Ok(())
529 }
530 DataValue::Vector(v) => {
531 let components: Vec<String> = v.iter().map(|f| f.to_string()).collect();
532 self.values.push(format!("[{}]", components.join(",")));
533 Ok(())
534 }
535 }
536 }
537
538 #[must_use]
539 pub fn finalize(self) -> DataValue {
540 if self.values.is_empty() {
541 DataValue::Null
542 } else {
543 DataValue::String(self.values.join(&self.separator))
544 }
545 }
546}
547
548pub struct AggregateRegistry {
550 functions: Vec<Box<dyn AggregateFunction>>,
551}
552
553impl AggregateRegistry {
554 #[must_use]
555 pub fn new() -> Self {
556 use analytics::{
557 CumMaxFunction, CumMinFunction, DeltasFunction, MavgFunction, PctChangeFunction,
558 RankFunction, SumsFunction,
559 };
560 use functions::{
561 AvgFunction, MaxFunction, MedianFunction, MinFunction, ModeFunction,
562 PercentileFunction, StdDevFunction, StdDevPopFunction, StdDevSampFunction,
563 StringAggFunction, VarPopFunction, VarSampFunction, VarianceFunction,
564 };
565
566 let functions: Vec<Box<dyn AggregateFunction>> = vec![
567 Box::new(AvgFunction),
571 Box::new(MinFunction),
572 Box::new(MaxFunction),
573 Box::new(StdDevFunction),
574 Box::new(StdDevPopFunction),
575 Box::new(StdDevSampFunction),
576 Box::new(VarianceFunction),
577 Box::new(VarPopFunction),
578 Box::new(VarSampFunction),
579 Box::new(MedianFunction),
580 Box::new(ModeFunction),
581 Box::new(PercentileFunction),
582 Box::new(StringAggFunction),
583 Box::new(DeltasFunction),
585 Box::new(SumsFunction),
586 Box::new(MavgFunction),
587 Box::new(PctChangeFunction),
588 Box::new(RankFunction),
589 Box::new(CumMaxFunction),
590 Box::new(CumMinFunction),
591 ];
592
593 Self { functions }
594 }
595
596 #[must_use]
597 pub fn get(&self, name: &str) -> Option<&dyn AggregateFunction> {
598 let name_upper = name.to_uppercase();
599 self.functions
600 .iter()
601 .find(|f| f.name() == name_upper)
602 .map(std::convert::AsRef::as_ref)
603 }
604
605 #[must_use]
606 pub fn is_aggregate(&self, name: &str) -> bool {
607 use crate::sql::aggregate_functions::AggregateFunctionRegistry;
608
609 if self.get(name).is_some() {
611 return true;
612 }
613
614 let new_registry = AggregateFunctionRegistry::new();
616 new_registry.contains(name)
617 }
618}
619
620impl Default for AggregateRegistry {
621 fn default() -> Self {
622 Self::new()
623 }
624}
625
626pub fn contains_aggregate(expr: &crate::recursive_parser::SqlExpression) -> bool {
628 use crate::recursive_parser::SqlExpression;
629 use crate::sql::aggregate_functions::AggregateFunctionRegistry;
630
631 match expr {
632 SqlExpression::FunctionCall { name, args, .. } => {
633 let registry = AggregateRegistry::new();
635 if registry.is_aggregate(name) {
636 return true;
637 }
638 let new_registry = AggregateFunctionRegistry::new();
640 if new_registry.contains(name) {
641 return true;
642 }
643 args.iter().any(contains_aggregate)
645 }
646 SqlExpression::BinaryOp { left, right, .. } => {
647 contains_aggregate(left) || contains_aggregate(right)
648 }
649 SqlExpression::Not { expr } => contains_aggregate(expr),
650 SqlExpression::CaseExpression {
651 when_branches,
652 else_branch,
653 } => {
654 when_branches.iter().any(|branch| {
655 contains_aggregate(&branch.condition) || contains_aggregate(&branch.result)
656 }) || else_branch.as_ref().is_some_and(|e| contains_aggregate(e))
657 }
658 _ => false,
659 }
660}
661
662pub fn is_constant_expression(expr: &crate::recursive_parser::SqlExpression) -> bool {
665 use crate::recursive_parser::SqlExpression;
666
667 match expr {
668 SqlExpression::StringLiteral(_) => true,
669 SqlExpression::NumberLiteral(_) => true,
670 SqlExpression::BooleanLiteral(_) => true,
671 SqlExpression::Null => true,
672 SqlExpression::DateTimeConstructor { .. } => true,
673 SqlExpression::DateTimeToday { .. } => true,
674 SqlExpression::BinaryOp { left, right, .. } => {
676 is_constant_expression(left) && is_constant_expression(right)
677 }
678 SqlExpression::Not { expr } => is_constant_expression(expr),
680 SqlExpression::CaseExpression {
682 when_branches,
683 else_branch,
684 } => {
685 when_branches.iter().all(|branch| {
686 is_constant_expression(&branch.condition) && is_constant_expression(&branch.result)
687 }) && else_branch
688 .as_ref()
689 .map_or(true, |e| is_constant_expression(e))
690 }
691 SqlExpression::FunctionCall { args, .. } => {
694 !contains_aggregate(expr) && args.iter().all(is_constant_expression)
696 }
697 _ => false,
698 }
699}
700
701pub fn is_aggregate_compatible(expr: &crate::recursive_parser::SqlExpression) -> bool {
704 contains_aggregate(expr) || is_constant_expression(expr)
705}