1use anyhow::{anyhow, Result};
7
8use crate::data::datatable::DataValue;
9
10pub mod analytics;
11pub mod functions;
12
13#[derive(Debug, Clone)]
15pub enum AggregateState {
16 Count(i64),
17 Sum(SumState),
18 Avg(AvgState),
19 MinMax(MinMaxState),
20 Variance(VarianceState),
21 CollectList(Vec<DataValue>),
22 Percentile(PercentileState),
23 Mode(ModeState),
24 Analytics(analytics::AnalyticsState),
25 StringAgg(StringAggState),
26}
27
28#[derive(Debug, Clone)]
30pub struct SumState {
31 pub int_sum: Option<i64>,
32 pub float_sum: Option<f64>,
33 pub has_values: bool,
34}
35
36impl Default for SumState {
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42impl SumState {
43 #[must_use]
44 pub fn new() -> Self {
45 Self {
46 int_sum: None,
47 float_sum: None,
48 has_values: false,
49 }
50 }
51
52 pub fn add(&mut self, value: &DataValue) -> Result<()> {
53 match value {
54 DataValue::Null => Ok(()), DataValue::Integer(n) => {
56 self.has_values = true;
57 if let Some(ref mut sum) = self.int_sum {
58 *sum = sum.saturating_add(*n);
59 } else if let Some(ref mut fsum) = self.float_sum {
60 *fsum += *n as f64;
61 } else {
62 self.int_sum = Some(*n);
63 }
64 Ok(())
65 }
66 DataValue::Float(f) => {
67 self.has_values = true;
68 if let Some(isum) = self.int_sum.take() {
70 self.float_sum = Some(isum as f64 + f);
71 } else if let Some(ref mut fsum) = self.float_sum {
72 *fsum += f;
73 } else {
74 self.float_sum = Some(*f);
75 }
76 Ok(())
77 }
78 _ => Err(anyhow!("Cannot sum non-numeric value")),
79 }
80 }
81
82 #[must_use]
83 pub fn finalize(self) -> DataValue {
84 if !self.has_values {
85 return DataValue::Null;
86 }
87
88 if let Some(fsum) = self.float_sum {
89 DataValue::Float(fsum)
90 } else if let Some(isum) = self.int_sum {
91 DataValue::Integer(isum)
92 } else {
93 DataValue::Null
94 }
95 }
96}
97
98#[derive(Debug, Clone)]
100pub struct AvgState {
101 pub sum: SumState,
102 pub count: i64,
103}
104
105impl Default for AvgState {
106 fn default() -> Self {
107 Self::new()
108 }
109}
110
111impl AvgState {
112 #[must_use]
113 pub fn new() -> Self {
114 Self {
115 sum: SumState::new(),
116 count: 0,
117 }
118 }
119
120 pub fn add(&mut self, value: &DataValue) -> Result<()> {
121 if !matches!(value, DataValue::Null) {
122 self.sum.add(value)?;
123 self.count += 1;
124 }
125 Ok(())
126 }
127
128 #[must_use]
129 pub fn finalize(self) -> DataValue {
130 if self.count == 0 {
131 return DataValue::Null;
132 }
133
134 let sum = self.sum.finalize();
135 match sum {
136 DataValue::Integer(n) => DataValue::Float(n as f64 / self.count as f64),
137 DataValue::Float(f) => DataValue::Float(f / self.count as f64),
138 _ => DataValue::Null,
139 }
140 }
141}
142
143#[derive(Debug, Clone)]
145pub struct MinMaxState {
146 pub is_min: bool,
147 pub current: Option<DataValue>,
148}
149
150impl MinMaxState {
151 #[must_use]
152 pub fn new(is_min: bool) -> Self {
153 Self {
154 is_min,
155 current: None,
156 }
157 }
158
159 pub fn add(&mut self, value: &DataValue) -> Result<()> {
160 if matches!(value, DataValue::Null) {
161 return Ok(());
162 }
163
164 if let Some(ref current) = self.current {
165 let should_update = if self.is_min {
166 value < current
167 } else {
168 value > current
169 };
170
171 if should_update {
172 self.current = Some(value.clone());
173 }
174 } else {
175 self.current = Some(value.clone());
176 }
177
178 Ok(())
179 }
180
181 #[must_use]
182 pub fn finalize(self) -> DataValue {
183 self.current.unwrap_or(DataValue::Null)
184 }
185}
186
187#[derive(Debug, Clone)]
189pub struct VarianceState {
190 pub sum: f64,
191 pub sum_of_squares: f64,
192 pub count: i64,
193}
194
195impl Default for VarianceState {
196 fn default() -> Self {
197 Self::new()
198 }
199}
200
201impl VarianceState {
202 #[must_use]
203 pub fn new() -> Self {
204 Self {
205 sum: 0.0,
206 sum_of_squares: 0.0,
207 count: 0,
208 }
209 }
210
211 pub fn add(&mut self, value: &DataValue) -> Result<()> {
212 match value {
213 DataValue::Null => Ok(()), DataValue::Integer(n) => {
215 let f = *n as f64;
216 self.sum += f;
217 self.sum_of_squares += f * f;
218 self.count += 1;
219 Ok(())
220 }
221 DataValue::Float(f) => {
222 self.sum += f;
223 self.sum_of_squares += f * f;
224 self.count += 1;
225 Ok(())
226 }
227 _ => Err(anyhow!("Cannot compute variance of non-numeric value")),
228 }
229 }
230
231 #[must_use]
232 pub fn variance(&self) -> f64 {
233 if self.count <= 1 {
234 return 0.0;
235 }
236 let mean = self.sum / self.count as f64;
237 (self.sum_of_squares / self.count as f64) - (mean * mean)
238 }
239
240 #[must_use]
241 pub fn stddev(&self) -> f64 {
242 self.variance().sqrt()
243 }
244
245 #[must_use]
246 pub fn finalize_variance(self) -> DataValue {
247 if self.count == 0 {
248 DataValue::Null
249 } else {
250 DataValue::Float(self.variance())
251 }
252 }
253
254 #[must_use]
255 pub fn finalize_stddev(self) -> DataValue {
256 if self.count == 0 {
257 DataValue::Null
258 } else {
259 DataValue::Float(self.stddev())
260 }
261 }
262}
263
264#[derive(Debug, Clone)]
266pub struct PercentileState {
267 pub values: Vec<DataValue>,
268 pub percentile: f64,
269}
270
271impl Default for PercentileState {
272 fn default() -> Self {
273 Self::new(50.0) }
275}
276
277impl PercentileState {
278 #[must_use]
279 pub fn new(percentile: f64) -> Self {
280 Self {
281 values: Vec::new(),
282 percentile: percentile.clamp(0.0, 100.0),
283 }
284 }
285
286 pub fn add(&mut self, value: &DataValue) -> Result<()> {
287 if !matches!(value, DataValue::Null) {
288 self.values.push(value.clone());
289 }
290 Ok(())
291 }
292
293 #[must_use]
294 pub fn finalize(mut self) -> DataValue {
295 if self.values.is_empty() {
296 return DataValue::Null;
297 }
298
299 self.values.sort_by(|a, b| {
301 use std::cmp::Ordering;
302 match (a, b) {
303 (DataValue::Integer(a), DataValue::Integer(b)) => a.cmp(b),
304 (DataValue::Float(a), DataValue::Float(b)) => {
305 a.partial_cmp(b).unwrap_or(Ordering::Equal)
306 }
307 (DataValue::Integer(a), DataValue::Float(b)) => {
308 (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal)
309 }
310 (DataValue::Float(a), DataValue::Integer(b)) => {
311 a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal)
312 }
313 _ => Ordering::Equal,
314 }
315 });
316
317 let n = self.values.len();
318 if self.percentile == 0.0 {
319 return self.values[0].clone();
320 }
321 if self.percentile == 100.0 {
322 return self.values[n - 1].clone();
323 }
324
325 let pos = (self.percentile / 100.0) * ((n - 1) as f64);
327 let lower_idx = pos.floor() as usize;
328 let upper_idx = pos.ceil() as usize;
329
330 if lower_idx == upper_idx {
331 self.values[lower_idx].clone()
333 } else {
334 let fraction = pos - lower_idx as f64;
336 let lower_val = &self.values[lower_idx];
337 let upper_val = &self.values[upper_idx];
338
339 match (lower_val, upper_val) {
340 (DataValue::Integer(a), DataValue::Integer(b)) => {
341 let result = *a as f64 + fraction * (*b - *a) as f64;
342 if result.fract() == 0.0 {
343 DataValue::Integer(result as i64)
344 } else {
345 DataValue::Float(result)
346 }
347 }
348 (DataValue::Float(a), DataValue::Float(b)) => {
349 DataValue::Float(a + fraction * (b - a))
350 }
351 (DataValue::Integer(a), DataValue::Float(b)) => {
352 DataValue::Float(*a as f64 + fraction * (b - *a as f64))
353 }
354 (DataValue::Float(a), DataValue::Integer(b)) => {
355 DataValue::Float(a + fraction * (*b as f64 - a))
356 }
357 _ => lower_val.clone(),
359 }
360 }
361 }
362}
363
364#[derive(Debug, Clone)]
366pub struct ModeState {
367 pub counts: std::collections::HashMap<String, (DataValue, i64)>,
368}
369
370impl Default for ModeState {
371 fn default() -> Self {
372 Self::new()
373 }
374}
375
376impl ModeState {
377 #[must_use]
378 pub fn new() -> Self {
379 Self {
380 counts: std::collections::HashMap::new(),
381 }
382 }
383
384 pub fn add(&mut self, value: &DataValue) -> Result<()> {
385 if matches!(value, DataValue::Null) {
386 return Ok(());
387 }
388
389 let key = match value {
391 DataValue::String(s) => s.clone(),
392 DataValue::InternedString(s) => s.to_string(),
393 DataValue::Integer(i) => i.to_string(),
394 DataValue::Float(f) => f.to_string(),
395 DataValue::Boolean(b) => b.to_string(),
396 DataValue::DateTime(dt) => dt.to_string(),
397 DataValue::Null => return Ok(()),
398 };
399
400 let entry = self.counts.entry(key).or_insert((value.clone(), 0));
402 entry.1 += 1;
403
404 Ok(())
405 }
406
407 #[must_use]
408 pub fn finalize(self) -> DataValue {
409 if self.counts.is_empty() {
410 return DataValue::Null;
411 }
412
413 let max_entry = self.counts.iter().max_by_key(|(_, (_, count))| count);
415
416 match max_entry {
417 Some((_, (value, _count))) => value.clone(),
418 None => DataValue::Null,
419 }
420 }
421}
422
423pub trait AggregateFunction: Send + Sync {
425 fn name(&self) -> &str;
427
428 fn init(&self) -> AggregateState;
430
431 fn accumulate(&self, state: &mut AggregateState, value: &DataValue) -> Result<()>;
433
434 fn finalize(&self, state: AggregateState) -> DataValue;
436
437 fn requires_numeric(&self) -> bool {
439 false
440 }
441}
442
443#[derive(Debug, Clone)]
445pub struct StringAggState {
446 pub values: Vec<String>,
447 pub separator: String,
448}
449
450impl Default for StringAggState {
451 fn default() -> Self {
452 Self::new(",")
453 }
454}
455
456impl StringAggState {
457 #[must_use]
458 pub fn new(separator: &str) -> Self {
459 Self {
460 values: Vec::new(),
461 separator: separator.to_string(),
462 }
463 }
464
465 pub fn add(&mut self, value: &DataValue) -> Result<()> {
466 match value {
467 DataValue::Null => Ok(()), DataValue::String(s) => {
469 self.values.push(s.clone());
470 Ok(())
471 }
472 DataValue::InternedString(s) => {
473 self.values.push(s.to_string());
474 Ok(())
475 }
476 DataValue::Integer(n) => {
477 self.values.push(n.to_string());
478 Ok(())
479 }
480 DataValue::Float(f) => {
481 self.values.push(f.to_string());
482 Ok(())
483 }
484 DataValue::Boolean(b) => {
485 self.values.push(b.to_string());
486 Ok(())
487 }
488 DataValue::DateTime(dt) => {
489 self.values.push(dt.to_string());
490 Ok(())
491 }
492 }
493 }
494
495 #[must_use]
496 pub fn finalize(self) -> DataValue {
497 if self.values.is_empty() {
498 DataValue::Null
499 } else {
500 DataValue::String(self.values.join(&self.separator))
501 }
502 }
503}
504
505pub struct AggregateRegistry {
507 functions: Vec<Box<dyn AggregateFunction>>,
508}
509
510impl AggregateRegistry {
511 #[must_use]
512 pub fn new() -> Self {
513 use analytics::{
514 CumMaxFunction, CumMinFunction, DeltasFunction, MavgFunction, PctChangeFunction,
515 RankFunction, SumsFunction,
516 };
517 use functions::{
518 AvgFunction, CountFunction, CountStarFunction, MaxFunction, MedianFunction,
519 MinFunction, ModeFunction, PercentileFunction, StdDevFunction, StringAggFunction,
520 SumFunction, VarianceFunction,
521 };
522
523 let functions: Vec<Box<dyn AggregateFunction>> = vec![
524 Box::new(CountFunction),
525 Box::new(CountStarFunction),
526 Box::new(SumFunction),
527 Box::new(AvgFunction),
528 Box::new(MinFunction),
529 Box::new(MaxFunction),
530 Box::new(StdDevFunction),
531 Box::new(VarianceFunction),
532 Box::new(MedianFunction),
533 Box::new(ModeFunction),
534 Box::new(PercentileFunction),
535 Box::new(StringAggFunction),
536 Box::new(DeltasFunction),
538 Box::new(SumsFunction),
539 Box::new(MavgFunction),
540 Box::new(PctChangeFunction),
541 Box::new(RankFunction),
542 Box::new(CumMaxFunction),
543 Box::new(CumMinFunction),
544 ];
545
546 Self { functions }
547 }
548
549 #[must_use]
550 pub fn get(&self, name: &str) -> Option<&dyn AggregateFunction> {
551 let name_upper = name.to_uppercase();
552 self.functions
553 .iter()
554 .find(|f| f.name() == name_upper)
555 .map(std::convert::AsRef::as_ref)
556 }
557
558 #[must_use]
559 pub fn is_aggregate(&self, name: &str) -> bool {
560 self.get(name).is_some() || name.to_uppercase() == "COUNT" }
562}
563
564impl Default for AggregateRegistry {
565 fn default() -> Self {
566 Self::new()
567 }
568}
569
570pub fn contains_aggregate(expr: &crate::recursive_parser::SqlExpression) -> bool {
572 use crate::recursive_parser::SqlExpression;
573
574 match expr {
575 SqlExpression::FunctionCall { name, args, .. } => {
576 let registry = AggregateRegistry::new();
577 if registry.is_aggregate(name) {
578 return true;
579 }
580 args.iter().any(contains_aggregate)
582 }
583 SqlExpression::BinaryOp { left, right, .. } => {
584 contains_aggregate(left) || contains_aggregate(right)
585 }
586 SqlExpression::Not { expr } => contains_aggregate(expr),
587 SqlExpression::CaseExpression {
588 when_branches,
589 else_branch,
590 } => {
591 when_branches.iter().any(|branch| {
592 contains_aggregate(&branch.condition) || contains_aggregate(&branch.result)
593 }) || else_branch.as_ref().is_some_and(|e| contains_aggregate(e))
594 }
595 _ => false,
596 }
597}
598
599pub fn is_constant_expression(expr: &crate::recursive_parser::SqlExpression) -> bool {
602 use crate::recursive_parser::SqlExpression;
603
604 match expr {
605 SqlExpression::StringLiteral(_) => true,
606 SqlExpression::NumberLiteral(_) => true,
607 SqlExpression::BooleanLiteral(_) => true,
608 SqlExpression::Null => true,
609 SqlExpression::DateTimeConstructor { .. } => true,
610 SqlExpression::DateTimeToday { .. } => true,
611 SqlExpression::BinaryOp { left, right, .. } => {
613 is_constant_expression(left) && is_constant_expression(right)
614 }
615 SqlExpression::Not { expr } => is_constant_expression(expr),
617 SqlExpression::CaseExpression {
619 when_branches,
620 else_branch,
621 } => {
622 when_branches.iter().all(|branch| {
623 is_constant_expression(&branch.condition) && is_constant_expression(&branch.result)
624 }) && else_branch
625 .as_ref()
626 .map_or(true, |e| is_constant_expression(e))
627 }
628 SqlExpression::FunctionCall { args, .. } => {
631 !contains_aggregate(expr) && args.iter().all(is_constant_expression)
633 }
634 _ => false,
635 }
636}
637
638pub fn is_aggregate_compatible(expr: &crate::recursive_parser::SqlExpression) -> bool {
641 contains_aggregate(expr) || is_constant_expression(expr)
642}