1use alloc::boxed::Box;
24use alloc::collections::BTreeMap;
25use alloc::format;
26use alloc::string::{String, ToString};
27use alloc::vec::Vec;
28
29use spg_sql::ast::{Expr, SelectItem, SelectStatement};
30use spg_storage::{ColumnSchema, DataType, Row, Value};
31
32use crate::eval::{self, EvalContext, EvalError};
33
34pub fn uses_aggregate(stmt: &SelectStatement) -> bool {
36 if stmt.group_by.is_some() || stmt.having.is_some() {
37 return true;
38 }
39 for item in &stmt.items {
40 if let SelectItem::Expr { expr, .. } = item
41 && contains_aggregate(expr)
42 {
43 return true;
44 }
45 }
46 for o in &stmt.order_by {
47 if contains_aggregate(&o.expr) {
48 return true;
49 }
50 }
51 if let Some(h) = &stmt.having
52 && contains_aggregate(h)
53 {
54 return true;
55 }
56 false
57}
58
59pub fn contains_aggregate(e: &Expr) -> bool {
60 match e {
61 Expr::FunctionCall { name, args } => {
62 is_aggregate_name(name) || args.iter().any(contains_aggregate)
63 }
64 Expr::Binary { lhs, rhs, .. } => contains_aggregate(lhs) || contains_aggregate(rhs),
65 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
66 contains_aggregate(expr)
67 }
68 Expr::Like { expr, pattern, .. } => contains_aggregate(expr) || contains_aggregate(pattern),
69 Expr::Extract { source, .. } => contains_aggregate(source),
70 Expr::ScalarSubquery(_)
75 | Expr::Exists { .. }
76 | Expr::InSubquery { .. }
77 | Expr::WindowFunction { .. }
78 | Expr::Literal(_)
79 | Expr::Placeholder(_)
80 | Expr::Column(_) => false,
81 Expr::Array(items) => items.iter().any(contains_aggregate),
85 Expr::ArraySubscript { target, index } => {
86 contains_aggregate(target) || contains_aggregate(index)
87 }
88 Expr::AnyAll { expr, array, .. } => contains_aggregate(expr) || contains_aggregate(array),
89 Expr::Case {
92 operand,
93 branches,
94 else_branch,
95 } => {
96 operand.as_deref().is_some_and(contains_aggregate)
97 || branches
98 .iter()
99 .any(|(w, t)| contains_aggregate(w) || contains_aggregate(t))
100 || else_branch.as_deref().is_some_and(contains_aggregate)
101 }
102 }
103}
104
/// Returns `true` when `name` is one of the supported aggregate functions
/// (`count`, `count_star`, `sum`, `min`, `max`, `avg`).
///
/// The comparison ignores ASCII case and does not allocate.
pub fn is_aggregate_name(name: &str) -> bool {
    const AGGREGATES: [&str; 6] = ["count", "count_star", "sum", "min", "max", "avg"];
    AGGREGATES
        .iter()
        .any(|known| name.eq_ignore_ascii_case(known))
}
111
112#[derive(Debug, Default, Clone)]
114struct AggState {
115 count: i64,
116 sum_int: i64,
117 sum_float: f64,
118 extreme: Option<Value>,
119 use_float: bool,
120}
121
122#[derive(Debug, Clone)]
123struct AggSpec {
124 name: String, arg: Option<Expr>,
127}
128
129#[derive(Debug)]
132pub struct AggResult {
133 pub columns: Vec<ColumnSchema>,
134 pub rows: Vec<Row>,
135}
136
137#[allow(clippy::too_many_lines)]
140pub fn run(
141 stmt: &SelectStatement,
142 rows: &[&Row],
143 schema_cols: &[ColumnSchema],
144 table_alias: Option<&str>,
145) -> Result<AggResult, EvalError> {
146 let ctx = EvalContext::new(schema_cols, table_alias);
147 let group_exprs: Vec<Expr> = stmt.group_by.clone().unwrap_or_default();
148
149 let mut agg_specs: Vec<AggSpec> = Vec::new();
151 for item in &stmt.items {
152 if let SelectItem::Expr { expr, .. } = item {
153 collect_aggregates(expr, &mut agg_specs);
154 }
155 }
156 for o in &stmt.order_by {
157 collect_aggregates(&o.expr, &mut agg_specs);
158 }
159 if let Some(h) = &stmt.having {
160 collect_aggregates(h, &mut agg_specs);
161 }
162
163 let mut groups: BTreeMap<String, (Vec<Value>, Vec<AggState>)> = BTreeMap::new();
166 let mut key_order: Vec<String> = Vec::new();
167 if rows.is_empty() && group_exprs.is_empty() {
170 let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
172 groups.insert(String::new(), (Vec::new(), init));
173 key_order.push(String::new());
174 }
175
176 for row in rows {
177 let group_vals: Vec<Value> = group_exprs
178 .iter()
179 .map(|g| eval::eval_expr(g, row, &ctx))
180 .collect::<Result<_, _>>()?;
181 let key = encode_key(&group_vals);
182 let entry = groups.entry(key.clone()).or_insert_with(|| {
183 key_order.push(key.clone());
184 let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
185 (group_vals.clone(), init)
186 });
187 for (i, spec) in agg_specs.iter().enumerate() {
188 let arg_val = match &spec.arg {
189 None => Value::Bool(true), Some(e) => eval::eval_expr(e, row, &ctx)?,
191 };
192 update_state(&mut entry.1[i], &spec.name, &arg_val)?;
193 }
194 }
195
196 let group_types: Vec<DataType> = if rows.is_empty() {
198 group_exprs.iter().map(|_| DataType::Text).collect()
201 } else {
202 let probe = rows[0];
203 group_exprs
204 .iter()
205 .map(|g| {
206 eval::eval_expr(g, probe, &ctx).map(|v| v.data_type().unwrap_or(DataType::Text))
207 })
208 .collect::<Result<_, _>>()?
209 };
210 let agg_types: Vec<DataType> = agg_specs.iter().map(infer_agg_type).collect();
211 let mut synth_schema: Vec<ColumnSchema> = Vec::new();
212 for (i, ty) in group_types.iter().enumerate() {
213 synth_schema.push(ColumnSchema::new(format!("__grp_{i}"), *ty, true));
214 }
215 for (i, ty) in agg_types.iter().enumerate() {
216 synth_schema.push(ColumnSchema::new(format!("__agg_{i}"), *ty, true));
217 }
218
219 let mut synth_rows: Vec<Row> = Vec::new();
221 for k in &key_order {
222 let (gvals, states) = &groups[k];
223 let mut values: Vec<Value> = Vec::with_capacity(synth_schema.len());
224 values.extend(gvals.iter().cloned());
225 for (i, st) in states.iter().enumerate() {
226 values.push(finalize(&agg_specs[i].name, st));
227 }
228 synth_rows.push(Row::new(values));
229 }
230
231 let columns: Vec<ColumnSchema> = stmt
236 .items
237 .iter()
238 .map(|item| match item {
239 SelectItem::Wildcard => Err(EvalError::TypeMismatch {
240 detail: "SELECT * with aggregates is not supported".into(),
241 }),
242 SelectItem::Expr { expr, alias } => {
243 let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
244 let name = alias.clone().unwrap_or_else(|| expr.to_string());
245 Ok(ColumnSchema::new(
246 name,
247 agg_or_group_type(&rewritten, &synth_schema),
248 true,
249 ))
250 }
251 })
252 .collect::<Result<_, _>>()?;
253
254 let synth_ctx = EvalContext::new(&synth_schema, None);
259 let having_rewritten = stmt
260 .having
261 .as_ref()
262 .map(|h| rewrite_expr(h, &group_exprs, &agg_specs));
263 let mut kept_synth: Vec<Row> = Vec::new();
264 let mut out_rows: Vec<Row> = Vec::new();
265 for srow in synth_rows {
266 if let Some(h) = &having_rewritten {
267 let cond = eval::eval_expr(h, &srow, &synth_ctx)?;
268 if !matches!(cond, Value::Bool(true)) {
269 continue;
270 }
271 }
272 let mut values: Vec<Value> = Vec::with_capacity(columns.len());
273 for item in &stmt.items {
274 if let SelectItem::Expr { expr, .. } = item {
275 let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
276 values.push(eval::eval_expr(&rewritten, &srow, &synth_ctx)?);
277 }
278 }
279 kept_synth.push(srow);
280 out_rows.push(Row::new(values));
281 }
282
283 if !stmt.order_by.is_empty() {
286 let rewritten: Vec<Expr> = stmt
289 .order_by
290 .iter()
291 .map(|o| rewrite_expr(&o.expr, &group_exprs, &agg_specs))
292 .collect();
293 let descs: Vec<bool> = stmt.order_by.iter().map(|o| o.desc).collect();
294 let mut tagged: Vec<(Vec<Value>, Row)> = kept_synth
295 .into_iter()
296 .zip(out_rows)
297 .map(|(s, o)| {
298 let mut keys = Vec::with_capacity(rewritten.len());
299 for e in &rewritten {
300 keys.push(eval::eval_expr(e, &s, &synth_ctx)?);
301 }
302 Ok::<_, EvalError>((keys, o))
303 })
304 .collect::<Result<_, _>>()?;
305 tagged.sort_by(|a, b| {
306 use core::cmp::Ordering;
307 for (i, (ka, kb)) in a.0.iter().zip(b.0.iter()).enumerate() {
308 let cmp = value_cmp(ka, kb);
309 let cmp = if descs[i] { cmp.reverse() } else { cmp };
310 if cmp != Ordering::Equal {
311 return cmp;
312 }
313 }
314 Ordering::Equal
315 });
316 out_rows = tagged.into_iter().map(|(_, o)| o).collect();
317 }
318
319 Ok(AggResult {
320 columns,
321 rows: out_rows,
322 })
323}
324
325fn collect_aggregates(e: &Expr, out: &mut Vec<AggSpec>) {
326 match e {
327 Expr::FunctionCall { name, args } => {
328 let lower = name.to_ascii_lowercase();
329 if is_aggregate_name(&lower) {
330 let arg = if lower == "count_star" {
331 None
332 } else {
333 args.first().cloned()
334 };
335 let spec = AggSpec {
336 name: lower,
337 arg: arg.clone(),
338 };
339 if !out.iter().any(|s| s.name == spec.name && s.arg == spec.arg) {
340 out.push(spec);
341 }
342 } else {
345 for a in args {
346 collect_aggregates(a, out);
347 }
348 }
349 }
350 Expr::Binary { lhs, rhs, .. } => {
351 collect_aggregates(lhs, out);
352 collect_aggregates(rhs, out);
353 }
354 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
355 collect_aggregates(expr, out);
356 }
357 Expr::Like { expr, pattern, .. } => {
358 collect_aggregates(expr, out);
359 collect_aggregates(pattern, out);
360 }
361 Expr::Extract { source, .. } => collect_aggregates(source, out),
362 Expr::ScalarSubquery(_)
365 | Expr::Exists { .. }
366 | Expr::InSubquery { .. }
367 | Expr::WindowFunction { .. }
368 | Expr::Literal(_)
369 | Expr::Placeholder(_)
370 | Expr::Column(_) => {}
371 Expr::Array(items) => {
374 for elem in items {
375 collect_aggregates(elem, out);
376 }
377 }
378 Expr::ArraySubscript { target, index } => {
379 collect_aggregates(target, out);
380 collect_aggregates(index, out);
381 }
382 Expr::AnyAll { expr, array, .. } => {
383 collect_aggregates(expr, out);
384 collect_aggregates(array, out);
385 }
386 Expr::Case {
387 operand,
388 branches,
389 else_branch,
390 } => {
391 if let Some(o) = operand {
392 collect_aggregates(o, out);
393 }
394 for (w, t) in branches {
395 collect_aggregates(w, out);
396 collect_aggregates(t, out);
397 }
398 if let Some(e) = else_branch {
399 collect_aggregates(e, out);
400 }
401 }
402 }
403}
404
405fn update_state(st: &mut AggState, name: &str, v: &Value) -> Result<(), EvalError> {
406 let is_null = matches!(v, Value::Null);
407 match name {
408 "count_star" => st.count += 1,
409 "count" => {
410 if !is_null {
411 st.count += 1;
412 }
413 }
414 "sum" | "avg" => {
415 if is_null {
416 return Ok(());
417 }
418 st.count += 1;
419 match v {
420 Value::Int(n) => st.sum_int += i64::from(*n),
421 Value::BigInt(n) => st.sum_int += *n,
422 Value::Float(x) => {
423 st.use_float = true;
424 st.sum_float += *x;
425 }
426 other => {
427 return Err(EvalError::TypeMismatch {
428 detail: format!("sum/avg need numeric, got {:?}", other.data_type()),
429 });
430 }
431 }
432 }
433 "min" => {
434 if is_null {
435 return Ok(());
436 }
437 match &st.extreme {
438 None => st.extreme = Some(v.clone()),
439 Some(cur) => {
440 if value_cmp(v, cur) == core::cmp::Ordering::Less {
441 st.extreme = Some(v.clone());
442 }
443 }
444 }
445 }
446 "max" => {
447 if is_null {
448 return Ok(());
449 }
450 match &st.extreme {
451 None => st.extreme = Some(v.clone()),
452 Some(cur) => {
453 if value_cmp(v, cur) == core::cmp::Ordering::Greater {
454 st.extreme = Some(v.clone());
455 }
456 }
457 }
458 }
459 _ => unreachable!("non-aggregate {name} in update_state"),
460 }
461 Ok(())
462}
463
464#[allow(clippy::cast_precision_loss)]
465fn finalize(name: &str, st: &AggState) -> Value {
466 match name {
467 "count" | "count_star" => Value::BigInt(st.count),
468 "sum" => {
469 if st.count == 0 {
470 Value::Null
471 } else if st.use_float {
472 Value::Float(st.sum_float + (st.sum_int as f64))
473 } else {
474 Value::BigInt(st.sum_int)
475 }
476 }
477 "avg" => {
478 if st.count == 0 {
479 Value::Null
480 } else {
481 let total = if st.use_float {
482 st.sum_float + (st.sum_int as f64)
483 } else {
484 st.sum_int as f64
485 };
486 Value::Float(total / (st.count as f64))
487 }
488 }
489 "min" | "max" => st.extreme.clone().unwrap_or(Value::Null),
490 _ => unreachable!(),
491 }
492}
493
494fn infer_agg_type(spec: &AggSpec) -> DataType {
495 match spec.name.as_str() {
496 "count" | "count_star" | "sum" => DataType::BigInt,
500 "avg" => DataType::Float,
501 _ => DataType::Text,
504 }
505}
506
507fn agg_or_group_type(e: &Expr, synth: &[ColumnSchema]) -> DataType {
508 if let Expr::Column(c) = e
509 && let Some(s) = synth.iter().find(|s| s.name == c.name)
510 {
511 return s.ty;
512 }
513 DataType::Text
516}
517
518fn rewrite_expr(e: &Expr, group_exprs: &[Expr], aggs: &[AggSpec]) -> Expr {
519 if let Expr::FunctionCall { name, args } = e {
521 let lower = name.to_ascii_lowercase();
522 if is_aggregate_name(&lower) {
523 let arg = if lower == "count_star" {
524 None
525 } else {
526 args.first().cloned()
527 };
528 for (i, spec) in aggs.iter().enumerate() {
529 if spec.name == lower && spec.arg == arg {
530 return Expr::Column(spg_sql::ast::ColumnName {
531 qualifier: None,
532 name: format!("__agg_{i}"),
533 });
534 }
535 }
536 }
537 }
538 for (i, g) in group_exprs.iter().enumerate() {
540 if g == e {
541 return Expr::Column(spg_sql::ast::ColumnName {
542 qualifier: None,
543 name: format!("__grp_{i}"),
544 });
545 }
546 }
547 match e {
549 Expr::Binary { lhs, op, rhs } => Expr::Binary {
550 lhs: Box::new(rewrite_expr(lhs, group_exprs, aggs)),
551 op: *op,
552 rhs: Box::new(rewrite_expr(rhs, group_exprs, aggs)),
553 },
554 Expr::Unary { op, expr } => Expr::Unary {
555 op: *op,
556 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
557 },
558 Expr::Cast { expr, target } => Expr::Cast {
559 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
560 target: *target,
561 },
562 Expr::IsNull { expr, negated } => Expr::IsNull {
563 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
564 negated: *negated,
565 },
566 Expr::FunctionCall { name, args } => Expr::FunctionCall {
567 name: name.clone(),
568 args: args
569 .iter()
570 .map(|a| rewrite_expr(a, group_exprs, aggs))
571 .collect(),
572 },
573 Expr::Like {
574 expr,
575 pattern,
576 negated,
577 } => Expr::Like {
578 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
579 pattern: Box::new(rewrite_expr(pattern, group_exprs, aggs)),
580 negated: *negated,
581 },
582 Expr::Extract { field, source } => Expr::Extract {
583 field: *field,
584 source: Box::new(rewrite_expr(source, group_exprs, aggs)),
585 },
586 Expr::ScalarSubquery(_)
589 | Expr::Exists { .. }
590 | Expr::InSubquery { .. }
591 | Expr::WindowFunction { .. }
592 | Expr::Literal(_)
593 | Expr::Placeholder(_)
594 | Expr::Column(_) => e.clone(),
595 Expr::Array(items) => Expr::Array(
597 items
598 .iter()
599 .map(|elem| rewrite_expr(elem, group_exprs, aggs))
600 .collect(),
601 ),
602 Expr::ArraySubscript { target, index } => Expr::ArraySubscript {
603 target: Box::new(rewrite_expr(target, group_exprs, aggs)),
604 index: Box::new(rewrite_expr(index, group_exprs, aggs)),
605 },
606 Expr::AnyAll {
607 expr,
608 op,
609 array,
610 is_any,
611 } => Expr::AnyAll {
612 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
613 op: *op,
614 array: Box::new(rewrite_expr(array, group_exprs, aggs)),
615 is_any: *is_any,
616 },
617 Expr::Case {
618 operand,
619 branches,
620 else_branch,
621 } => Expr::Case {
622 operand: operand
623 .as_deref()
624 .map(|o| Box::new(rewrite_expr(o, group_exprs, aggs))),
625 branches: branches
626 .iter()
627 .map(|(w, t)| {
628 (
629 rewrite_expr(w, group_exprs, aggs),
630 rewrite_expr(t, group_exprs, aggs),
631 )
632 })
633 .collect(),
634 else_branch: else_branch
635 .as_deref()
636 .map(|e| Box::new(rewrite_expr(e, group_exprs, aggs))),
637 },
638 }
639}
640
641fn encode_key(vals: &[Value]) -> String {
643 let mut out = String::new();
644 for v in vals {
645 match v {
646 Value::Null => out.push_str("N|"),
647 Value::SmallInt(n) => {
648 out.push('s');
649 out.push_str(&n.to_string());
650 out.push('|');
651 }
652 Value::Int(n) => {
653 out.push('I');
654 out.push_str(&n.to_string());
655 out.push('|');
656 }
657 Value::BigInt(n) => {
658 out.push('B');
659 out.push_str(&n.to_string());
660 out.push('|');
661 }
662 Value::Float(x) => {
663 out.push('F');
664 out.push_str(&x.to_string());
665 out.push('|');
666 }
667 Value::Bool(b) => {
668 out.push(if *b { 'T' } else { 'f' });
669 out.push('|');
670 }
671 Value::Text(s) => {
672 out.push('S');
673 out.push_str(s);
674 out.push('|');
675 }
676 Value::Vector(v) => {
677 out.push('V');
678 for x in v {
679 out.push_str(&x.to_string());
680 out.push(',');
681 }
682 out.push('|');
683 }
684 Value::Sq8Vector(q) => {
690 out.push('Q');
691 out.push_str(&q.min.to_string());
692 out.push('@');
693 out.push_str(&q.max.to_string());
694 out.push(':');
695 for b in &q.bytes {
696 out.push_str(&b.to_string());
697 out.push(',');
698 }
699 out.push('|');
700 }
701 Value::HalfVector(h) => {
705 out.push('H');
706 for b in &h.bytes {
707 out.push_str(&b.to_string());
708 out.push(',');
709 }
710 out.push('|');
711 }
712 Value::Numeric { scaled, scale } => {
713 out.push('D');
714 out.push_str(&scaled.to_string());
715 out.push('@');
716 out.push_str(&scale.to_string());
717 out.push('|');
718 }
719 Value::Date(d) => {
720 out.push('d');
721 out.push_str(&d.to_string());
722 out.push('|');
723 }
724 Value::Timestamp(t) => {
725 out.push('t');
726 out.push_str(&t.to_string());
727 out.push('|');
728 }
729 Value::Interval { months, micros } => {
730 out.push('i');
731 out.push_str(&months.to_string());
732 out.push('m');
733 out.push_str(µs.to_string());
734 out.push('|');
735 }
736 Value::Json(s) => {
737 out.push('j');
738 out.push_str(s);
739 out.push('|');
740 }
741 _ => {
746 out.push('?');
747 out.push_str(&format!("{v:?}"));
748 out.push('|');
749 }
750 }
751 }
752 out
753}
754
755#[allow(clippy::cast_precision_loss)]
756fn value_cmp(a: &Value, b: &Value) -> core::cmp::Ordering {
757 use core::cmp::Ordering::Equal;
758 match (a, b) {
759 (Value::Null, Value::Null) => Equal,
760 (Value::Null, _) => core::cmp::Ordering::Greater, (_, Value::Null) => core::cmp::Ordering::Less,
762 (Value::Int(x), Value::Int(y)) => x.cmp(y),
763 (Value::BigInt(x), Value::BigInt(y)) => x.cmp(y),
764 (Value::Int(x), Value::BigInt(y)) => i64::from(*x).cmp(y),
765 (Value::BigInt(x), Value::Int(y)) => x.cmp(&i64::from(*y)),
766 (Value::Float(x), Value::Float(y)) => x.partial_cmp(y).unwrap_or(Equal),
767 (Value::Int(x), Value::Float(y)) => f64::from(*x).partial_cmp(y).unwrap_or(Equal),
768 (Value::Float(x), Value::Int(y)) => x.partial_cmp(&f64::from(*y)).unwrap_or(Equal),
769 (Value::BigInt(x), Value::Float(y)) => (*x as f64).partial_cmp(y).unwrap_or(Equal),
770 (Value::Float(x), Value::BigInt(y)) => x.partial_cmp(&(*y as f64)).unwrap_or(Equal),
771 (Value::Text(x), Value::Text(y)) => x.cmp(y),
772 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
773 _ => Equal,
774 }
775}