1use alloc::boxed::Box;
24use alloc::collections::BTreeMap;
25use alloc::format;
26use alloc::string::{String, ToString};
27use alloc::vec::Vec;
28
29use spg_sql::ast::{Expr, SelectItem, SelectStatement};
30use spg_storage::{ColumnSchema, DataType, Row, Value};
31
32use crate::eval::{self, EvalContext, EvalError};
33
34pub fn uses_aggregate(stmt: &SelectStatement) -> bool {
36 if stmt.group_by.is_some() || stmt.having.is_some() {
37 return true;
38 }
39 for item in &stmt.items {
40 if let SelectItem::Expr { expr, .. } = item
41 && contains_aggregate(expr)
42 {
43 return true;
44 }
45 }
46 for o in &stmt.order_by {
47 if contains_aggregate(&o.expr) {
48 return true;
49 }
50 }
51 if let Some(h) = &stmt.having
52 && contains_aggregate(h)
53 {
54 return true;
55 }
56 false
57}
58
59pub fn contains_aggregate(e: &Expr) -> bool {
60 match e {
61 Expr::FunctionCall { name, args } => {
62 is_aggregate_name(name) || args.iter().any(contains_aggregate)
63 }
64 Expr::Binary { lhs, rhs, .. } => contains_aggregate(lhs) || contains_aggregate(rhs),
65 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
66 contains_aggregate(expr)
67 }
68 Expr::Like { expr, pattern, .. } => contains_aggregate(expr) || contains_aggregate(pattern),
69 Expr::Extract { source, .. } => contains_aggregate(source),
70 Expr::ScalarSubquery(_)
75 | Expr::Exists { .. }
76 | Expr::InSubquery { .. }
77 | Expr::WindowFunction { .. }
78 | Expr::Literal(_)
79 | Expr::Placeholder(_)
80 | Expr::Column(_) => false,
81 Expr::Array(items) => items.iter().any(contains_aggregate),
85 Expr::ArraySubscript { target, index } => {
86 contains_aggregate(target) || contains_aggregate(index)
87 }
88 Expr::AnyAll { expr, array, .. } => contains_aggregate(expr) || contains_aggregate(array),
89 }
90}
91
/// Returns `true` if `name`, compared ASCII case-insensitively, is one of the supported
/// aggregate functions: `count`, `count_star`, `sum`, `min`, `max`, `avg`.
pub fn is_aggregate_name(name: &str) -> bool {
    const AGGREGATES: [&str; 6] = ["count", "count_star", "sum", "min", "max", "avg"];
    // Compare in place instead of building a lowercased copy: this runs for every function
    // node visited while scanning and rewriting expressions.
    AGGREGATES.iter().any(|agg| name.eq_ignore_ascii_case(agg))
}
98
/// Running accumulator for one aggregate call within one group.
///
/// Which fields are used depends on the aggregate; see `update_state` (folding) and
/// `finalize` (producing the result value).
#[derive(Debug, Default, Clone)]
struct AggState {
    /// Rows seen (`count_star`) or non-NULL inputs folded (`count`, `sum`, `avg`).
    count: i64,
    /// Integer partial sum for `sum`/`avg`.
    sum_int: i64,
    /// Floating-point partial sum for `sum`/`avg`.
    sum_float: f64,
    /// Current best value for `min`/`max`; `None` until a non-NULL input arrives.
    extreme: Option<Value>,
    /// Set once a floating-point input has been summed; the result is then produced as a
    /// float combining both partial sums.
    use_float: bool,
}
108
/// One distinct aggregate call referenced by the statement.
///
/// Its position in the collected list determines the synthetic column name `__agg_<i>`.
#[derive(Debug, Clone)]
struct AggSpec {
    /// Lowercased aggregate name (`count`, `count_star`, `sum`, `min`, `max`, `avg`).
    name: String,
    /// First argument expression only; `None` for `count_star`.
    arg: Option<Expr>,
}
115
/// Output of [`run`]: the projected column schemas plus one row per group that passed
/// `HAVING`, sorted by `ORDER BY` when present (otherwise in first-seen group order).
#[derive(Debug)]
pub struct AggResult {
    /// One entry per `SELECT` item, named by its alias or, failing that, the expression text.
    pub columns: Vec<ColumnSchema>,
    /// The projected output rows.
    pub rows: Vec<Row>,
}
123
124#[allow(clippy::too_many_lines)]
127pub fn run(
128 stmt: &SelectStatement,
129 rows: &[&Row],
130 schema_cols: &[ColumnSchema],
131 table_alias: Option<&str>,
132) -> Result<AggResult, EvalError> {
133 let ctx = EvalContext::new(schema_cols, table_alias);
134 let group_exprs: Vec<Expr> = stmt.group_by.clone().unwrap_or_default();
135
136 let mut agg_specs: Vec<AggSpec> = Vec::new();
138 for item in &stmt.items {
139 if let SelectItem::Expr { expr, .. } = item {
140 collect_aggregates(expr, &mut agg_specs);
141 }
142 }
143 for o in &stmt.order_by {
144 collect_aggregates(&o.expr, &mut agg_specs);
145 }
146 if let Some(h) = &stmt.having {
147 collect_aggregates(h, &mut agg_specs);
148 }
149
150 let mut groups: BTreeMap<String, (Vec<Value>, Vec<AggState>)> = BTreeMap::new();
153 let mut key_order: Vec<String> = Vec::new();
154 if rows.is_empty() && group_exprs.is_empty() {
157 let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
159 groups.insert(String::new(), (Vec::new(), init));
160 key_order.push(String::new());
161 }
162
163 for row in rows {
164 let group_vals: Vec<Value> = group_exprs
165 .iter()
166 .map(|g| eval::eval_expr(g, row, &ctx))
167 .collect::<Result<_, _>>()?;
168 let key = encode_key(&group_vals);
169 let entry = groups.entry(key.clone()).or_insert_with(|| {
170 key_order.push(key.clone());
171 let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
172 (group_vals.clone(), init)
173 });
174 for (i, spec) in agg_specs.iter().enumerate() {
175 let arg_val = match &spec.arg {
176 None => Value::Bool(true), Some(e) => eval::eval_expr(e, row, &ctx)?,
178 };
179 update_state(&mut entry.1[i], &spec.name, &arg_val)?;
180 }
181 }
182
183 let group_types: Vec<DataType> = if rows.is_empty() {
185 group_exprs.iter().map(|_| DataType::Text).collect()
188 } else {
189 let probe = rows[0];
190 group_exprs
191 .iter()
192 .map(|g| {
193 eval::eval_expr(g, probe, &ctx).map(|v| v.data_type().unwrap_or(DataType::Text))
194 })
195 .collect::<Result<_, _>>()?
196 };
197 let agg_types: Vec<DataType> = agg_specs.iter().map(infer_agg_type).collect();
198 let mut synth_schema: Vec<ColumnSchema> = Vec::new();
199 for (i, ty) in group_types.iter().enumerate() {
200 synth_schema.push(ColumnSchema::new(format!("__grp_{i}"), *ty, true));
201 }
202 for (i, ty) in agg_types.iter().enumerate() {
203 synth_schema.push(ColumnSchema::new(format!("__agg_{i}"), *ty, true));
204 }
205
206 let mut synth_rows: Vec<Row> = Vec::new();
208 for k in &key_order {
209 let (gvals, states) = &groups[k];
210 let mut values: Vec<Value> = Vec::with_capacity(synth_schema.len());
211 values.extend(gvals.iter().cloned());
212 for (i, st) in states.iter().enumerate() {
213 values.push(finalize(&agg_specs[i].name, st));
214 }
215 synth_rows.push(Row::new(values));
216 }
217
218 let columns: Vec<ColumnSchema> = stmt
223 .items
224 .iter()
225 .map(|item| match item {
226 SelectItem::Wildcard => Err(EvalError::TypeMismatch {
227 detail: "SELECT * with aggregates is not supported".into(),
228 }),
229 SelectItem::Expr { expr, alias } => {
230 let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
231 let name = alias.clone().unwrap_or_else(|| expr.to_string());
232 Ok(ColumnSchema::new(
233 name,
234 agg_or_group_type(&rewritten, &synth_schema),
235 true,
236 ))
237 }
238 })
239 .collect::<Result<_, _>>()?;
240
241 let synth_ctx = EvalContext::new(&synth_schema, None);
246 let having_rewritten = stmt
247 .having
248 .as_ref()
249 .map(|h| rewrite_expr(h, &group_exprs, &agg_specs));
250 let mut kept_synth: Vec<Row> = Vec::new();
251 let mut out_rows: Vec<Row> = Vec::new();
252 for srow in synth_rows {
253 if let Some(h) = &having_rewritten {
254 let cond = eval::eval_expr(h, &srow, &synth_ctx)?;
255 if !matches!(cond, Value::Bool(true)) {
256 continue;
257 }
258 }
259 let mut values: Vec<Value> = Vec::with_capacity(columns.len());
260 for item in &stmt.items {
261 if let SelectItem::Expr { expr, .. } = item {
262 let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
263 values.push(eval::eval_expr(&rewritten, &srow, &synth_ctx)?);
264 }
265 }
266 kept_synth.push(srow);
267 out_rows.push(Row::new(values));
268 }
269
270 if !stmt.order_by.is_empty() {
273 let rewritten: Vec<Expr> = stmt
276 .order_by
277 .iter()
278 .map(|o| rewrite_expr(&o.expr, &group_exprs, &agg_specs))
279 .collect();
280 let descs: Vec<bool> = stmt.order_by.iter().map(|o| o.desc).collect();
281 let mut tagged: Vec<(Vec<Value>, Row)> = kept_synth
282 .into_iter()
283 .zip(out_rows)
284 .map(|(s, o)| {
285 let mut keys = Vec::with_capacity(rewritten.len());
286 for e in &rewritten {
287 keys.push(eval::eval_expr(e, &s, &synth_ctx)?);
288 }
289 Ok::<_, EvalError>((keys, o))
290 })
291 .collect::<Result<_, _>>()?;
292 tagged.sort_by(|a, b| {
293 use core::cmp::Ordering;
294 for (i, (ka, kb)) in a.0.iter().zip(b.0.iter()).enumerate() {
295 let cmp = value_cmp(ka, kb);
296 let cmp = if descs[i] { cmp.reverse() } else { cmp };
297 if cmp != Ordering::Equal {
298 return cmp;
299 }
300 }
301 Ordering::Equal
302 });
303 out_rows = tagged.into_iter().map(|(_, o)| o).collect();
304 }
305
306 Ok(AggResult {
307 columns,
308 rows: out_rows,
309 })
310}
311
312fn collect_aggregates(e: &Expr, out: &mut Vec<AggSpec>) {
313 match e {
314 Expr::FunctionCall { name, args } => {
315 let lower = name.to_ascii_lowercase();
316 if is_aggregate_name(&lower) {
317 let arg = if lower == "count_star" {
318 None
319 } else {
320 args.first().cloned()
321 };
322 let spec = AggSpec {
323 name: lower,
324 arg: arg.clone(),
325 };
326 if !out.iter().any(|s| s.name == spec.name && s.arg == spec.arg) {
327 out.push(spec);
328 }
329 } else {
332 for a in args {
333 collect_aggregates(a, out);
334 }
335 }
336 }
337 Expr::Binary { lhs, rhs, .. } => {
338 collect_aggregates(lhs, out);
339 collect_aggregates(rhs, out);
340 }
341 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
342 collect_aggregates(expr, out);
343 }
344 Expr::Like { expr, pattern, .. } => {
345 collect_aggregates(expr, out);
346 collect_aggregates(pattern, out);
347 }
348 Expr::Extract { source, .. } => collect_aggregates(source, out),
349 Expr::ScalarSubquery(_)
352 | Expr::Exists { .. }
353 | Expr::InSubquery { .. }
354 | Expr::WindowFunction { .. }
355 | Expr::Literal(_)
356 | Expr::Placeholder(_)
357 | Expr::Column(_) => {}
358 Expr::Array(items) => {
361 for elem in items {
362 collect_aggregates(elem, out);
363 }
364 }
365 Expr::ArraySubscript { target, index } => {
366 collect_aggregates(target, out);
367 collect_aggregates(index, out);
368 }
369 Expr::AnyAll { expr, array, .. } => {
370 collect_aggregates(expr, out);
371 collect_aggregates(array, out);
372 }
373 }
374}
375
376fn update_state(st: &mut AggState, name: &str, v: &Value) -> Result<(), EvalError> {
377 let is_null = matches!(v, Value::Null);
378 match name {
379 "count_star" => st.count += 1,
380 "count" => {
381 if !is_null {
382 st.count += 1;
383 }
384 }
385 "sum" | "avg" => {
386 if is_null {
387 return Ok(());
388 }
389 st.count += 1;
390 match v {
391 Value::Int(n) => st.sum_int += i64::from(*n),
392 Value::BigInt(n) => st.sum_int += *n,
393 Value::Float(x) => {
394 st.use_float = true;
395 st.sum_float += *x;
396 }
397 other => {
398 return Err(EvalError::TypeMismatch {
399 detail: format!("sum/avg need numeric, got {:?}", other.data_type()),
400 });
401 }
402 }
403 }
404 "min" => {
405 if is_null {
406 return Ok(());
407 }
408 match &st.extreme {
409 None => st.extreme = Some(v.clone()),
410 Some(cur) => {
411 if value_cmp(v, cur) == core::cmp::Ordering::Less {
412 st.extreme = Some(v.clone());
413 }
414 }
415 }
416 }
417 "max" => {
418 if is_null {
419 return Ok(());
420 }
421 match &st.extreme {
422 None => st.extreme = Some(v.clone()),
423 Some(cur) => {
424 if value_cmp(v, cur) == core::cmp::Ordering::Greater {
425 st.extreme = Some(v.clone());
426 }
427 }
428 }
429 }
430 _ => unreachable!("non-aggregate {name} in update_state"),
431 }
432 Ok(())
433}
434
435#[allow(clippy::cast_precision_loss)]
436fn finalize(name: &str, st: &AggState) -> Value {
437 match name {
438 "count" | "count_star" => Value::BigInt(st.count),
439 "sum" => {
440 if st.count == 0 {
441 Value::Null
442 } else if st.use_float {
443 Value::Float(st.sum_float + (st.sum_int as f64))
444 } else {
445 Value::BigInt(st.sum_int)
446 }
447 }
448 "avg" => {
449 if st.count == 0 {
450 Value::Null
451 } else {
452 let total = if st.use_float {
453 st.sum_float + (st.sum_int as f64)
454 } else {
455 st.sum_int as f64
456 };
457 Value::Float(total / (st.count as f64))
458 }
459 }
460 "min" | "max" => st.extreme.clone().unwrap_or(Value::Null),
461 _ => unreachable!(),
462 }
463}
464
465fn infer_agg_type(spec: &AggSpec) -> DataType {
466 match spec.name.as_str() {
467 "count" | "count_star" | "sum" => DataType::BigInt,
471 "avg" => DataType::Float,
472 _ => DataType::Text,
475 }
476}
477
478fn agg_or_group_type(e: &Expr, synth: &[ColumnSchema]) -> DataType {
479 if let Expr::Column(c) = e
480 && let Some(s) = synth.iter().find(|s| s.name == c.name)
481 {
482 return s.ty;
483 }
484 DataType::Text
487}
488
489fn rewrite_expr(e: &Expr, group_exprs: &[Expr], aggs: &[AggSpec]) -> Expr {
490 if let Expr::FunctionCall { name, args } = e {
492 let lower = name.to_ascii_lowercase();
493 if is_aggregate_name(&lower) {
494 let arg = if lower == "count_star" {
495 None
496 } else {
497 args.first().cloned()
498 };
499 for (i, spec) in aggs.iter().enumerate() {
500 if spec.name == lower && spec.arg == arg {
501 return Expr::Column(spg_sql::ast::ColumnName {
502 qualifier: None,
503 name: format!("__agg_{i}"),
504 });
505 }
506 }
507 }
508 }
509 for (i, g) in group_exprs.iter().enumerate() {
511 if g == e {
512 return Expr::Column(spg_sql::ast::ColumnName {
513 qualifier: None,
514 name: format!("__grp_{i}"),
515 });
516 }
517 }
518 match e {
520 Expr::Binary { lhs, op, rhs } => Expr::Binary {
521 lhs: Box::new(rewrite_expr(lhs, group_exprs, aggs)),
522 op: *op,
523 rhs: Box::new(rewrite_expr(rhs, group_exprs, aggs)),
524 },
525 Expr::Unary { op, expr } => Expr::Unary {
526 op: *op,
527 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
528 },
529 Expr::Cast { expr, target } => Expr::Cast {
530 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
531 target: *target,
532 },
533 Expr::IsNull { expr, negated } => Expr::IsNull {
534 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
535 negated: *negated,
536 },
537 Expr::FunctionCall { name, args } => Expr::FunctionCall {
538 name: name.clone(),
539 args: args
540 .iter()
541 .map(|a| rewrite_expr(a, group_exprs, aggs))
542 .collect(),
543 },
544 Expr::Like {
545 expr,
546 pattern,
547 negated,
548 } => Expr::Like {
549 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
550 pattern: Box::new(rewrite_expr(pattern, group_exprs, aggs)),
551 negated: *negated,
552 },
553 Expr::Extract { field, source } => Expr::Extract {
554 field: *field,
555 source: Box::new(rewrite_expr(source, group_exprs, aggs)),
556 },
557 Expr::ScalarSubquery(_)
560 | Expr::Exists { .. }
561 | Expr::InSubquery { .. }
562 | Expr::WindowFunction { .. }
563 | Expr::Literal(_)
564 | Expr::Placeholder(_)
565 | Expr::Column(_) => e.clone(),
566 Expr::Array(items) => Expr::Array(
568 items
569 .iter()
570 .map(|elem| rewrite_expr(elem, group_exprs, aggs))
571 .collect(),
572 ),
573 Expr::ArraySubscript { target, index } => Expr::ArraySubscript {
574 target: Box::new(rewrite_expr(target, group_exprs, aggs)),
575 index: Box::new(rewrite_expr(index, group_exprs, aggs)),
576 },
577 Expr::AnyAll {
578 expr,
579 op,
580 array,
581 is_any,
582 } => Expr::AnyAll {
583 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
584 op: *op,
585 array: Box::new(rewrite_expr(array, group_exprs, aggs)),
586 is_any: *is_any,
587 },
588 }
589}
590
591fn encode_key(vals: &[Value]) -> String {
593 let mut out = String::new();
594 for v in vals {
595 match v {
596 Value::Null => out.push_str("N|"),
597 Value::SmallInt(n) => {
598 out.push('s');
599 out.push_str(&n.to_string());
600 out.push('|');
601 }
602 Value::Int(n) => {
603 out.push('I');
604 out.push_str(&n.to_string());
605 out.push('|');
606 }
607 Value::BigInt(n) => {
608 out.push('B');
609 out.push_str(&n.to_string());
610 out.push('|');
611 }
612 Value::Float(x) => {
613 out.push('F');
614 out.push_str(&x.to_string());
615 out.push('|');
616 }
617 Value::Bool(b) => {
618 out.push(if *b { 'T' } else { 'f' });
619 out.push('|');
620 }
621 Value::Text(s) => {
622 out.push('S');
623 out.push_str(s);
624 out.push('|');
625 }
626 Value::Vector(v) => {
627 out.push('V');
628 for x in v {
629 out.push_str(&x.to_string());
630 out.push(',');
631 }
632 out.push('|');
633 }
634 Value::Sq8Vector(q) => {
640 out.push('Q');
641 out.push_str(&q.min.to_string());
642 out.push('@');
643 out.push_str(&q.max.to_string());
644 out.push(':');
645 for b in &q.bytes {
646 out.push_str(&b.to_string());
647 out.push(',');
648 }
649 out.push('|');
650 }
651 Value::HalfVector(h) => {
655 out.push('H');
656 for b in &h.bytes {
657 out.push_str(&b.to_string());
658 out.push(',');
659 }
660 out.push('|');
661 }
662 Value::Numeric { scaled, scale } => {
663 out.push('D');
664 out.push_str(&scaled.to_string());
665 out.push('@');
666 out.push_str(&scale.to_string());
667 out.push('|');
668 }
669 Value::Date(d) => {
670 out.push('d');
671 out.push_str(&d.to_string());
672 out.push('|');
673 }
674 Value::Timestamp(t) => {
675 out.push('t');
676 out.push_str(&t.to_string());
677 out.push('|');
678 }
679 Value::Interval { months, micros } => {
680 out.push('i');
681 out.push_str(&months.to_string());
682 out.push('m');
683 out.push_str(µs.to_string());
684 out.push('|');
685 }
686 Value::Json(s) => {
687 out.push('j');
688 out.push_str(s);
689 out.push('|');
690 }
691 _ => {
696 out.push('?');
697 out.push_str(&format!("{v:?}"));
698 out.push('|');
699 }
700 }
701 }
702 out
703}
704
705#[allow(clippy::cast_precision_loss)]
706fn value_cmp(a: &Value, b: &Value) -> core::cmp::Ordering {
707 use core::cmp::Ordering::Equal;
708 match (a, b) {
709 (Value::Null, Value::Null) => Equal,
710 (Value::Null, _) => core::cmp::Ordering::Greater, (_, Value::Null) => core::cmp::Ordering::Less,
712 (Value::Int(x), Value::Int(y)) => x.cmp(y),
713 (Value::BigInt(x), Value::BigInt(y)) => x.cmp(y),
714 (Value::Int(x), Value::BigInt(y)) => i64::from(*x).cmp(y),
715 (Value::BigInt(x), Value::Int(y)) => x.cmp(&i64::from(*y)),
716 (Value::Float(x), Value::Float(y)) => x.partial_cmp(y).unwrap_or(Equal),
717 (Value::Int(x), Value::Float(y)) => f64::from(*x).partial_cmp(y).unwrap_or(Equal),
718 (Value::Float(x), Value::Int(y)) => x.partial_cmp(&f64::from(*y)).unwrap_or(Equal),
719 (Value::BigInt(x), Value::Float(y)) => (*x as f64).partial_cmp(y).unwrap_or(Equal),
720 (Value::Float(x), Value::BigInt(y)) => x.partial_cmp(&(*y as f64)).unwrap_or(Equal),
721 (Value::Text(x), Value::Text(y)) => x.cmp(y),
722 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
723 _ => Equal,
724 }
725}