1use alloc::boxed::Box;
24use alloc::collections::{BTreeMap, BTreeSet};
25use alloc::format;
26use alloc::string::{String, ToString};
27use alloc::vec::Vec;
28
29use spg_sql::ast::{Expr, SelectItem, SelectStatement};
30use spg_storage::{ColumnSchema, DataType, Row, Value};
31
32use crate::eval::{self, EvalContext, EvalError};
33
34pub fn uses_aggregate(stmt: &SelectStatement) -> bool {
36 if stmt.group_by.is_some() || stmt.having.is_some() {
37 return true;
38 }
39 for item in &stmt.items {
40 if let SelectItem::Expr { expr, .. } = item
41 && contains_aggregate(expr)
42 {
43 return true;
44 }
45 }
46 for o in &stmt.order_by {
47 if contains_aggregate(&o.expr) {
48 return true;
49 }
50 }
51 if let Some(h) = &stmt.having
52 && contains_aggregate(h)
53 {
54 return true;
55 }
56 false
57}
58
59pub fn contains_aggregate(e: &Expr) -> bool {
60 match e {
61 Expr::FunctionCall { name, args } => {
62 is_aggregate_name(name) || args.iter().any(contains_aggregate)
63 }
64 Expr::AggregateOrdered { .. } => true,
65 Expr::Binary { lhs, rhs, .. } => contains_aggregate(lhs) || contains_aggregate(rhs),
66 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
67 contains_aggregate(expr)
68 }
69 Expr::Like { expr, pattern, .. } => contains_aggregate(expr) || contains_aggregate(pattern),
70 Expr::Extract { source, .. } => contains_aggregate(source),
71 Expr::ScalarSubquery(_)
76 | Expr::Exists { .. }
77 | Expr::InSubquery { .. }
78 | Expr::WindowFunction { .. }
79 | Expr::Literal(_)
80 | Expr::Placeholder(_)
81 | Expr::Column(_) => false,
82 Expr::Array(items) => items.iter().any(contains_aggregate),
86 Expr::ArraySubscript { target, index } => {
87 contains_aggregate(target) || contains_aggregate(index)
88 }
89 Expr::AnyAll { expr, array, .. } => contains_aggregate(expr) || contains_aggregate(array),
90 Expr::Case {
93 operand,
94 branches,
95 else_branch,
96 } => {
97 operand.as_deref().is_some_and(contains_aggregate)
98 || branches
99 .iter()
100 .any(|(w, t)| contains_aggregate(w) || contains_aggregate(t))
101 || else_branch.as_deref().is_some_and(contains_aggregate)
102 }
103 }
104}
105
/// Reports whether `name` (compared ASCII case-insensitively) is one of the
/// aggregate functions this module knows how to evaluate.
///
/// `count_star` is the internal name for `count(*)`; `every` is the SQL
/// standard spelling of `bool_and`.
pub fn is_aggregate_name(name: &str) -> bool {
    const AGGREGATES: [&str; 11] = [
        "count",
        "count_star",
        "sum",
        "min",
        "max",
        "avg",
        "string_agg",
        "array_agg",
        "bool_and",
        "bool_or",
        "every",
    ];
    AGGREGATES
        .iter()
        .any(|candidate| candidate.eq_ignore_ascii_case(name))
}
128
129#[derive(Debug, Default, Clone)]
131struct AggState {
132 count: i64,
133 sum_int: i64,
134 sum_float: f64,
135 extreme: Option<Value>,
136 use_float: bool,
137 items: Vec<Value>,
144 seen: BTreeSet<String>,
148 item_keys: Vec<Vec<Value>>,
152 separator: Option<String>,
158 bool_acc: Option<bool>,
162}
163
164#[derive(Debug, Clone)]
165struct AggSpec {
166 name: String, arg: Option<Expr>,
170 arg2: Option<Expr>,
176 distinct: bool,
179 order_by: Vec<spg_sql::ast::OrderBy>,
185}
186
187#[derive(Debug)]
190pub struct AggResult {
191 pub columns: Vec<ColumnSchema>,
192 pub rows: Vec<Row>,
193}
194
195#[allow(clippy::too_many_lines)]
198pub type CorrelatedEval<'a> = &'a dyn Fn(&Expr, &Row, &EvalContext<'_>) -> Result<Value, EvalError>;
205
206pub fn run(
207 stmt: &SelectStatement,
208 rows: &[&Row],
209 schema_cols: &[ColumnSchema],
210 table_alias: Option<&str>,
211 correlated_eval: Option<CorrelatedEval<'_>>,
212) -> Result<AggResult, EvalError> {
213 let ctx = EvalContext::new(schema_cols, table_alias);
214 let group_exprs: Vec<Expr> = stmt.group_by.clone().unwrap_or_default();
215
216 let mut agg_specs: Vec<AggSpec> = Vec::new();
218 for item in &stmt.items {
219 if let SelectItem::Expr { expr, .. } = item {
220 collect_aggregates(expr, &mut agg_specs);
221 }
222 }
223 for o in &stmt.order_by {
224 collect_aggregates(&o.expr, &mut agg_specs);
225 }
226 if let Some(h) = &stmt.having {
227 collect_aggregates(h, &mut agg_specs);
228 }
229 validate_agg_arities(stmt, &agg_specs)?;
235
236 let mut groups: BTreeMap<String, (Vec<Value>, Vec<AggState>)> = BTreeMap::new();
239 let mut key_order: Vec<String> = Vec::new();
240 if rows.is_empty() && group_exprs.is_empty() {
243 let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
245 groups.insert(String::new(), (Vec::new(), init));
246 key_order.push(String::new());
247 }
248
249 for row in rows {
250 let group_vals: Vec<Value> = group_exprs
251 .iter()
252 .map(|g| eval::eval_expr(g, row, &ctx))
253 .collect::<Result<_, _>>()?;
254 let mut key_vals = group_vals.clone();
260 for (i, g) in group_exprs.iter().enumerate() {
261 if matches!(
262 eval::column_collation(g, &ctx),
263 Some(spg_storage::Collation::CaseInsensitive)
264 ) {
265 if let Value::Text(s) = &key_vals[i] {
266 key_vals[i] = Value::Text(s.to_ascii_lowercase());
267 }
268 }
269 }
270 let key = encode_key(&key_vals);
271 let entry = groups.entry(key.clone()).or_insert_with(|| {
272 key_order.push(key.clone());
273 let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
274 (group_vals.clone(), init)
275 });
276 for (i, spec) in agg_specs.iter().enumerate() {
277 let arg_val = match &spec.arg {
278 None => Value::Bool(true), Some(e) => eval::eval_expr(e, row, &ctx)?,
280 };
281 let arg2_val = match &spec.arg2 {
287 None => None,
288 Some(e) => Some(eval::eval_expr(e, row, &ctx)?),
289 };
290 let order_keys = if spec.order_by.is_empty() {
293 None
294 } else {
295 let mut keys = Vec::with_capacity(spec.order_by.len());
296 for o in &spec.order_by {
297 keys.push(eval::eval_expr(&o.expr, row, &ctx)?);
298 }
299 Some(keys)
300 };
301 if spec.distinct {
306 let key = encode_key(core::slice::from_ref(&arg_val));
307 if !entry.1[i].seen.insert(key) {
308 continue;
309 }
310 }
311 update_state(
312 &mut entry.1[i],
313 &spec.name,
314 &arg_val,
315 arg2_val.as_ref(),
316 order_keys,
317 )?;
318 }
319 }
320
321 let group_types: Vec<DataType> = if rows.is_empty() {
323 group_exprs.iter().map(|_| DataType::Text).collect()
326 } else {
327 let probe = rows[0];
328 group_exprs
329 .iter()
330 .map(|g| {
331 eval::eval_expr(g, probe, &ctx).map(|v| v.data_type().unwrap_or(DataType::Text))
332 })
333 .collect::<Result<_, _>>()?
334 };
335 let agg_types: Vec<DataType> = agg_specs
336 .iter()
337 .map(|spec| infer_agg_type(spec, schema_cols))
338 .collect();
339 let mut synth_schema: Vec<ColumnSchema> = Vec::new();
340 for (i, ty) in group_types.iter().enumerate() {
341 synth_schema.push(ColumnSchema::new(format!("__grp_{i}"), *ty, true));
342 }
343 for (i, ty) in agg_types.iter().enumerate() {
344 synth_schema.push(ColumnSchema::new(format!("__agg_{i}"), *ty, true));
345 }
346
347 let mut synth_rows: Vec<Row> = Vec::new();
349 for k in &key_order {
350 let (gvals, states) = &groups[k];
351 let mut values: Vec<Value> = Vec::with_capacity(synth_schema.len());
352 values.extend(gvals.iter().cloned());
353 for (i, st) in states.iter().enumerate() {
354 let st_sorted;
358 let st_final: &AggState =
359 if !agg_specs[i].order_by.is_empty() && st.item_keys.len() == st.items.len() {
360 let mut idx: Vec<usize> = (0..st.items.len()).collect();
361 let ob = &agg_specs[i].order_by;
362 idx.sort_by(|&x, &y| {
363 for (k, o) in ob.iter().enumerate() {
364 let cmp = crate::order_by_value_cmp(
365 o.desc,
366 o.nulls_first,
367 &st.item_keys[x][k],
368 &st.item_keys[y][k],
369 );
370 if cmp != core::cmp::Ordering::Equal {
371 return cmp;
372 }
373 }
374 core::cmp::Ordering::Equal
375 });
376 let mut sorted = st.clone();
377 sorted.items = idx.iter().map(|&j| st.items[j].clone()).collect();
378 st_sorted = sorted;
379 &st_sorted
380 } else {
381 st
382 };
383 values.push(finalize(&agg_specs[i].name, st_final));
384 }
385 synth_rows.push(Row::new(values));
386 }
387
388 let columns: Vec<ColumnSchema> = stmt
393 .items
394 .iter()
395 .map(|item| match item {
396 SelectItem::Wildcard => Err(EvalError::TypeMismatch {
397 detail: "SELECT * with aggregates is not supported".into(),
398 }),
399 SelectItem::Expr { expr, alias } => {
400 let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
401 let name = alias.clone().unwrap_or_else(|| expr.to_string());
402 Ok(ColumnSchema::new(
403 name,
404 agg_or_group_type(&rewritten, &synth_schema),
405 true,
406 ))
407 }
408 })
409 .collect::<Result<_, _>>()?;
410
411 let synth_ctx = EvalContext::new(&synth_schema, None);
416 let having_rewritten = stmt
417 .having
418 .as_ref()
419 .map(|h| rewrite_expr(h, &group_exprs, &agg_specs));
420 let mut kept_synth: Vec<Row> = Vec::new();
421 let mut out_rows: Vec<Row> = Vec::new();
422 for srow in synth_rows {
423 if let Some(h) = &having_rewritten {
424 let cond = match correlated_eval {
425 Some(f) if crate::expr_has_subquery(h) => f(h, &srow, &synth_ctx)?,
426 _ => eval::eval_expr(h, &srow, &synth_ctx)?,
427 };
428 if !matches!(cond, Value::Bool(true)) {
429 continue;
430 }
431 }
432 let mut values: Vec<Value> = Vec::with_capacity(columns.len());
433 for item in &stmt.items {
434 if let SelectItem::Expr { expr, .. } = item {
435 let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
436 values.push(match correlated_eval {
437 Some(f) if crate::expr_has_subquery(&rewritten) => {
438 f(&rewritten, &srow, &synth_ctx)?
439 }
440 _ => eval::eval_expr(&rewritten, &srow, &synth_ctx)?,
441 });
442 }
443 }
444 kept_synth.push(srow);
445 out_rows.push(Row::new(values));
446 }
447
448 if !stmt.order_by.is_empty() {
451 let rewritten: Vec<Expr> = stmt
454 .order_by
455 .iter()
456 .map(|o| rewrite_expr(&o.expr, &group_exprs, &agg_specs))
457 .collect();
458 let keys_meta: Vec<(bool, Option<bool>)> = stmt
459 .order_by
460 .iter()
461 .map(|o| (o.desc, o.nulls_first))
462 .collect();
463 let mut tagged: Vec<(Vec<Value>, Row)> = kept_synth
464 .into_iter()
465 .zip(out_rows)
466 .map(|(s, o)| {
467 let mut keys = Vec::with_capacity(rewritten.len());
468 for e in &rewritten {
469 keys.push(match correlated_eval {
470 Some(f) if crate::expr_has_subquery(e) => f(e, &s, &synth_ctx)?,
471 _ => eval::eval_expr(e, &s, &synth_ctx)?,
472 });
473 }
474 Ok::<_, EvalError>((keys, o))
475 })
476 .collect::<Result<_, _>>()?;
477 tagged.sort_by(|a, b| {
478 use core::cmp::Ordering;
479 for (i, (ka, kb)) in a.0.iter().zip(b.0.iter()).enumerate() {
480 let (desc, nf) = keys_meta[i];
481 let cmp = crate::order_by_value_cmp(desc, nf, ka, kb);
482 if cmp != Ordering::Equal {
483 return cmp;
484 }
485 }
486 Ordering::Equal
487 });
488 out_rows = tagged.into_iter().map(|(_, o)| o).collect();
489 }
490
491 Ok(AggResult {
492 columns,
493 rows: out_rows,
494 })
495}
496
497fn validate_agg_arities(stmt: &SelectStatement, _specs: &[AggSpec]) -> Result<(), EvalError> {
503 fn walk(e: &Expr) -> Result<(), EvalError> {
504 if let Expr::FunctionCall { name, args } = e {
505 let lower = name.to_ascii_lowercase();
506 let expected: Option<usize> = match lower.as_str() {
507 "count_star" => Some(0),
508 "count" | "sum" | "avg" | "min" | "max" | "array_agg"
509 | "bool_and" | "bool_or" | "every" => Some(1),
513 "string_agg" => Some(2),
514 _ => None,
515 };
516 if let Some(want) = expected
517 && args.len() != want
518 {
519 return Err(EvalError::TypeMismatch {
520 detail: alloc::format!("{lower}() takes {want} arg(s), got {}", args.len()),
521 });
522 }
523 for a in args {
524 walk(a)?;
525 }
526 } else if let Expr::Binary { lhs, rhs, .. } = e {
527 walk(lhs)?;
528 walk(rhs)?;
529 } else if let Expr::Unary { expr, .. }
530 | Expr::Cast { expr, .. }
531 | Expr::IsNull { expr, .. } = e
532 {
533 walk(expr)?;
534 }
535 Ok(())
536 }
537 for item in &stmt.items {
538 if let SelectItem::Expr { expr, .. } = item {
539 walk(expr)?;
540 }
541 }
542 for o in &stmt.order_by {
543 walk(&o.expr)?;
544 }
545 if let Some(h) = &stmt.having {
546 walk(h)?;
547 }
548 Ok(())
549}
550
551fn collect_aggregates(e: &Expr, out: &mut Vec<AggSpec>) {
552 match e {
553 Expr::AggregateOrdered {
556 call,
557 order_by,
558 distinct,
559 } => {
560 if let Expr::FunctionCall { name, args } = call.as_ref() {
561 let lower = name.to_ascii_lowercase();
562 if is_aggregate_name(&lower) {
563 let canonical = if lower == "every" {
564 "bool_and".to_string()
565 } else {
566 lower
567 };
568 let spec = AggSpec {
569 name: canonical,
570 arg: args.first().cloned(),
571 arg2: if name.eq_ignore_ascii_case("string_agg") {
572 args.get(1).cloned()
573 } else {
574 None
575 },
576 distinct: *distinct,
577 order_by: order_by.clone(),
578 };
579 if !out.iter().any(|s| {
580 s.name == spec.name
581 && s.arg == spec.arg
582 && s.arg2 == spec.arg2
583 && s.distinct == spec.distinct
584 && s.order_by == spec.order_by
585 }) {
586 out.push(spec);
587 }
588 return;
589 }
590 }
591 collect_aggregates(call, out);
592 for o in order_by {
593 collect_aggregates(&o.expr, out);
594 }
595 }
596 Expr::FunctionCall { name, args } => {
597 let lower = name.to_ascii_lowercase();
598 if is_aggregate_name(&lower) {
599 let arg = if lower == "count_star" {
600 None
601 } else {
602 args.first().cloned()
603 };
604 let arg2 = if lower == "string_agg" {
608 args.get(1).cloned()
609 } else {
610 None
611 };
612 let canonical = if lower == "every" {
616 "bool_and".to_string()
617 } else {
618 lower
619 };
620 let spec = AggSpec {
621 name: canonical,
622 arg: arg.clone(),
623 arg2: arg2.clone(),
624 distinct: false,
625 order_by: Vec::new(),
626 };
627 if !out.iter().any(|s| {
628 s.name == spec.name
629 && s.arg == spec.arg
630 && s.arg2 == spec.arg2
631 && !s.distinct
632 && s.order_by == spec.order_by
633 }) {
634 out.push(spec);
635 }
636 } else {
639 for a in args {
640 collect_aggregates(a, out);
641 }
642 }
643 }
644 Expr::Binary { lhs, rhs, .. } => {
645 collect_aggregates(lhs, out);
646 collect_aggregates(rhs, out);
647 }
648 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
649 collect_aggregates(expr, out);
650 }
651 Expr::Like { expr, pattern, .. } => {
652 collect_aggregates(expr, out);
653 collect_aggregates(pattern, out);
654 }
655 Expr::Extract { source, .. } => collect_aggregates(source, out),
656 Expr::ScalarSubquery(_)
659 | Expr::Exists { .. }
660 | Expr::InSubquery { .. }
661 | Expr::WindowFunction { .. }
662 | Expr::Literal(_)
663 | Expr::Placeholder(_)
664 | Expr::Column(_) => {}
665 Expr::Array(items) => {
668 for elem in items {
669 collect_aggregates(elem, out);
670 }
671 }
672 Expr::ArraySubscript { target, index } => {
673 collect_aggregates(target, out);
674 collect_aggregates(index, out);
675 }
676 Expr::AnyAll { expr, array, .. } => {
677 collect_aggregates(expr, out);
678 collect_aggregates(array, out);
679 }
680 Expr::Case {
681 operand,
682 branches,
683 else_branch,
684 } => {
685 if let Some(o) = operand {
686 collect_aggregates(o, out);
687 }
688 for (w, t) in branches {
689 collect_aggregates(w, out);
690 collect_aggregates(t, out);
691 }
692 if let Some(e) = else_branch {
693 collect_aggregates(e, out);
694 }
695 }
696 }
697}
698
699fn update_state(
700 st: &mut AggState,
701 name: &str,
702 v: &Value,
703 arg2: Option<&Value>,
704 order_keys: Option<Vec<Value>>,
705) -> Result<(), EvalError> {
706 let is_null = matches!(v, Value::Null);
707 match name {
708 "count_star" => st.count += 1,
709 "count" => {
710 if !is_null {
711 st.count += 1;
712 }
713 }
714 "sum" | "avg" => {
715 if is_null {
716 return Ok(());
717 }
718 st.count += 1;
719 match v {
720 Value::Int(n) => st.sum_int += i64::from(*n),
721 Value::BigInt(n) => st.sum_int += *n,
722 Value::Float(x) => {
723 st.use_float = true;
724 st.sum_float += *x;
725 }
726 other => {
727 return Err(EvalError::TypeMismatch {
728 detail: format!("sum/avg need numeric, got {:?}", other.data_type()),
729 });
730 }
731 }
732 }
733 "min" => {
734 if is_null {
735 return Ok(());
736 }
737 match &st.extreme {
738 None => st.extreme = Some(v.clone()),
739 Some(cur) => {
740 if value_cmp(v, cur) == core::cmp::Ordering::Less {
741 st.extreme = Some(v.clone());
742 }
743 }
744 }
745 }
746 "max" => {
747 if is_null {
748 return Ok(());
749 }
750 match &st.extreme {
751 None => st.extreme = Some(v.clone()),
752 Some(cur) => {
753 if value_cmp(v, cur) == core::cmp::Ordering::Greater {
754 st.extreme = Some(v.clone());
755 }
756 }
757 }
758 }
759 "string_agg" => {
767 if let Some(sep) = arg2
768 && let Value::Text(s) = sep
769 {
770 st.separator = Some(s.clone());
771 }
772 if is_null {
773 return Ok(());
774 }
775 if let Value::Text(s) = v {
776 st.items.push(Value::Text(s.clone()));
777 if let Some(k) = order_keys {
778 st.item_keys.push(k);
779 }
780 st.count += 1;
781 } else {
782 return Err(EvalError::TypeMismatch {
783 detail: format!("string_agg requires text value, got {:?}", v.data_type()),
784 });
785 }
786 }
787 "array_agg" => {
793 st.items.push(v.clone());
794 if let Some(k) = order_keys {
795 st.item_keys.push(k);
796 }
797 st.count += 1;
798 }
799 "bool_and" => {
803 if is_null {
804 return Ok(());
805 }
806 let b = match v {
807 Value::Bool(b) => *b,
808 other => {
809 return Err(EvalError::TypeMismatch {
810 detail: format!("bool_and requires bool, got {:?}", other.data_type()),
811 });
812 }
813 };
814 st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc && b));
815 }
816 "bool_or" => {
819 if is_null {
820 return Ok(());
821 }
822 let b = match v {
823 Value::Bool(b) => *b,
824 other => {
825 return Err(EvalError::TypeMismatch {
826 detail: format!("bool_or requires bool, got {:?}", other.data_type()),
827 });
828 }
829 };
830 st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc || b));
831 }
832 _ => unreachable!("non-aggregate {name} in update_state"),
833 }
834 Ok(())
835}
836
837#[allow(clippy::cast_precision_loss)]
838fn finalize(name: &str, st: &AggState) -> Value {
839 match name {
840 "count" | "count_star" => Value::BigInt(st.count),
841 "sum" => {
842 if st.count == 0 {
843 Value::Null
844 } else if st.use_float {
845 Value::Float(st.sum_float + (st.sum_int as f64))
846 } else {
847 Value::BigInt(st.sum_int)
848 }
849 }
850 "avg" => {
851 if st.count == 0 {
852 Value::Null
853 } else {
854 let total = if st.use_float {
855 st.sum_float + (st.sum_int as f64)
856 } else {
857 st.sum_int as f64
858 };
859 Value::Float(total / (st.count as f64))
860 }
861 }
862 "min" | "max" => st.extreme.clone().unwrap_or(Value::Null),
863 "string_agg" => {
867 if st.items.is_empty() {
868 return Value::Null;
869 }
870 let sep = st.separator.clone().unwrap_or_default();
871 let mut out = String::new();
872 for (i, item) in st.items.iter().enumerate() {
873 if i > 0 {
874 out.push_str(&sep);
875 }
876 if let Value::Text(s) = item {
877 out.push_str(s);
878 }
879 }
880 Value::Text(out)
881 }
882 "array_agg" => {
889 if st.items.is_empty() {
890 return Value::Null;
891 }
892 let probe = st.items.iter().find(|v| !v.is_null());
893 match probe.and_then(spg_storage::Value::data_type) {
894 Some(DataType::Int) | Some(DataType::SmallInt) => {
895 let items: Vec<Option<i32>> = st
896 .items
897 .iter()
898 .map(|v| match v {
899 Value::Int(n) => Some(*n),
900 Value::SmallInt(n) => Some(i32::from(*n)),
901 _ => None,
902 })
903 .collect();
904 Value::IntArray(items)
905 }
906 Some(DataType::BigInt) => {
907 let items: Vec<Option<i64>> = st
908 .items
909 .iter()
910 .map(|v| match v {
911 Value::BigInt(n) => Some(*n),
912 _ => None,
913 })
914 .collect();
915 Value::BigIntArray(items)
916 }
917 _ => {
918 let items: Vec<Option<String>> = st
919 .items
920 .iter()
921 .map(|v| match v {
922 Value::Text(s) => Some(s.clone()),
923 Value::Null => None,
924 other => Some(format!("{other:?}")),
925 })
926 .collect();
927 Value::TextArray(items)
928 }
929 }
930 }
931 "bool_and" | "bool_or" => st.bool_acc.map_or(Value::Null, Value::Bool),
935 _ => unreachable!(),
936 }
937}
938
939fn infer_agg_type(spec: &AggSpec, schema_cols: &[ColumnSchema]) -> DataType {
940 let arg_ty = spec
944 .arg
945 .as_ref()
946 .and_then(|a| crate::describe::describe_expr(a, schema_cols))
947 .map(|shape| shape.ty);
948 match spec.name.as_str() {
949 "count" | "count_star" => DataType::BigInt,
950 "sum" => match arg_ty {
951 Some(DataType::Float) => DataType::Float,
952 _ => DataType::BigInt,
953 },
954 "avg" => DataType::Float,
955 "string_agg" => DataType::Text,
957 "array_agg" => match arg_ty {
958 Some(DataType::Int | DataType::SmallInt) => DataType::IntArray,
959 Some(DataType::BigInt) => DataType::BigIntArray,
960 _ => DataType::TextArray,
961 },
962 "bool_and" | "bool_or" => DataType::Bool,
965 _ => arg_ty.unwrap_or(DataType::Text),
967 }
968}
969
970fn agg_or_group_type(e: &Expr, synth: &[ColumnSchema]) -> DataType {
971 if let Expr::Column(c) = e
972 && let Some(s) = synth.iter().find(|s| s.name == c.name)
973 {
974 return s.ty;
975 }
976 crate::describe::describe_expr(e, synth)
982 .map(|shape| shape.ty)
983 .unwrap_or(DataType::Text)
984}
985
986fn rewrite_expr(e: &Expr, group_exprs: &[Expr], aggs: &[AggSpec]) -> Expr {
987 if let Expr::AggregateOrdered {
990 call,
991 order_by,
992 distinct,
993 } = e
994 && let Expr::FunctionCall { name, args } = call.as_ref()
995 {
996 let lower = name.to_ascii_lowercase();
997 if is_aggregate_name(&lower) {
998 let canonical: &str = if lower == "every" { "bool_and" } else { &lower };
999 let arg = args.first().cloned();
1000 let arg2 = if lower == "string_agg" {
1001 args.get(1).cloned()
1002 } else {
1003 None
1004 };
1005 for (i, spec) in aggs.iter().enumerate() {
1006 if spec.name == canonical
1007 && spec.arg == arg
1008 && spec.arg2 == arg2
1009 && spec.distinct == *distinct
1010 && spec.order_by == *order_by
1011 {
1012 return Expr::Column(spg_sql::ast::ColumnName {
1013 qualifier: None,
1014 name: format!("__agg_{i}"),
1015 });
1016 }
1017 }
1018 }
1019 }
1020 if let Expr::FunctionCall { name, args } = e {
1022 let lower = name.to_ascii_lowercase();
1023 if is_aggregate_name(&lower) {
1024 let arg = if lower == "count_star" {
1025 None
1026 } else {
1027 args.first().cloned()
1028 };
1029 let arg2 = if lower == "string_agg" {
1032 args.get(1).cloned()
1033 } else {
1034 None
1035 };
1036 let canonical: &str = if lower == "every" {
1040 "bool_and"
1041 } else {
1042 lower.as_str()
1043 };
1044 for (i, spec) in aggs.iter().enumerate() {
1045 if spec.name == canonical
1046 && spec.arg == arg
1047 && spec.arg2 == arg2
1048 && !spec.distinct
1049 && spec.order_by.is_empty()
1050 {
1051 return Expr::Column(spg_sql::ast::ColumnName {
1052 qualifier: None,
1053 name: format!("__agg_{i}"),
1054 });
1055 }
1056 }
1057 }
1058 }
1059 for (i, g) in group_exprs.iter().enumerate() {
1061 if g == e {
1062 return Expr::Column(spg_sql::ast::ColumnName {
1063 qualifier: None,
1064 name: format!("__grp_{i}"),
1065 });
1066 }
1067 }
1068 match e {
1070 Expr::AggregateOrdered {
1071 call,
1072 order_by,
1073 distinct,
1074 } => Expr::AggregateOrdered {
1075 call: Box::new(rewrite_expr(call, group_exprs, aggs)),
1076 distinct: *distinct,
1077 order_by: order_by
1078 .iter()
1079 .map(|o| spg_sql::ast::OrderBy {
1080 expr: rewrite_expr(&o.expr, group_exprs, aggs),
1081 desc: o.desc,
1082 nulls_first: o.nulls_first,
1083 })
1084 .collect(),
1085 },
1086 Expr::Binary { lhs, op, rhs } => Expr::Binary {
1087 lhs: Box::new(rewrite_expr(lhs, group_exprs, aggs)),
1088 op: *op,
1089 rhs: Box::new(rewrite_expr(rhs, group_exprs, aggs)),
1090 },
1091 Expr::Unary { op, expr } => Expr::Unary {
1092 op: *op,
1093 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1094 },
1095 Expr::Cast { expr, target } => Expr::Cast {
1096 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1097 target: *target,
1098 },
1099 Expr::IsNull { expr, negated } => Expr::IsNull {
1100 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1101 negated: *negated,
1102 },
1103 Expr::FunctionCall { name, args } => Expr::FunctionCall {
1104 name: name.clone(),
1105 args: args
1106 .iter()
1107 .map(|a| rewrite_expr(a, group_exprs, aggs))
1108 .collect(),
1109 },
1110 Expr::Like {
1111 expr,
1112 pattern,
1113 negated,
1114 case_insensitive,
1115 } => Expr::Like {
1116 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1117 pattern: Box::new(rewrite_expr(pattern, group_exprs, aggs)),
1118 negated: *negated,
1119 case_insensitive: *case_insensitive,
1120 },
1121 Expr::Extract { field, source } => Expr::Extract {
1122 field: *field,
1123 source: Box::new(rewrite_expr(source, group_exprs, aggs)),
1124 },
1125 Expr::ScalarSubquery(s) => {
1131 Expr::ScalarSubquery(Box::new(rewrite_group_keys_in_select(s, group_exprs)))
1132 }
1133 Expr::Exists { subquery, negated } => Expr::Exists {
1134 subquery: Box::new(rewrite_group_keys_in_select(subquery, group_exprs)),
1135 negated: *negated,
1136 },
1137 Expr::InSubquery {
1138 expr,
1139 subquery,
1140 negated,
1141 } => Expr::InSubquery {
1142 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1143 subquery: Box::new(rewrite_group_keys_in_select(subquery, group_exprs)),
1144 negated: *negated,
1145 },
1146 Expr::WindowFunction { .. } | Expr::Literal(_) | Expr::Placeholder(_) | Expr::Column(_) => {
1149 e.clone()
1150 }
1151 Expr::Array(items) => Expr::Array(
1153 items
1154 .iter()
1155 .map(|elem| rewrite_expr(elem, group_exprs, aggs))
1156 .collect(),
1157 ),
1158 Expr::ArraySubscript { target, index } => Expr::ArraySubscript {
1159 target: Box::new(rewrite_expr(target, group_exprs, aggs)),
1160 index: Box::new(rewrite_expr(index, group_exprs, aggs)),
1161 },
1162 Expr::AnyAll {
1163 expr,
1164 op,
1165 array,
1166 is_any,
1167 } => Expr::AnyAll {
1168 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1169 op: *op,
1170 array: Box::new(rewrite_expr(array, group_exprs, aggs)),
1171 is_any: *is_any,
1172 },
1173 Expr::Case {
1174 operand,
1175 branches,
1176 else_branch,
1177 } => Expr::Case {
1178 operand: operand
1179 .as_deref()
1180 .map(|o| Box::new(rewrite_expr(o, group_exprs, aggs))),
1181 branches: branches
1182 .iter()
1183 .map(|(w, t)| {
1184 (
1185 rewrite_expr(w, group_exprs, aggs),
1186 rewrite_expr(t, group_exprs, aggs),
1187 )
1188 })
1189 .collect(),
1190 else_branch: else_branch
1191 .as_deref()
1192 .map(|e| Box::new(rewrite_expr(e, group_exprs, aggs))),
1193 },
1194 }
1195}
1196
1197fn rewrite_group_keys_in_select(
1202 s: &spg_sql::ast::SelectStatement,
1203 group_exprs: &[Expr],
1204) -> spg_sql::ast::SelectStatement {
1205 let mut out = s.clone();
1206 let _ = crate::walk_select_exprs_mut(&mut out, &mut |e| {
1207 *e = rewrite_expr(e, group_exprs, &[]);
1208 Ok(())
1209 });
1210 out
1211}
1212
1213fn encode_key(vals: &[Value]) -> String {
1215 let mut out = String::new();
1216 for v in vals {
1217 match v {
1218 Value::Null => out.push_str("N|"),
1219 Value::SmallInt(n) => {
1220 out.push('s');
1221 out.push_str(&n.to_string());
1222 out.push('|');
1223 }
1224 Value::Int(n) => {
1225 out.push('I');
1226 out.push_str(&n.to_string());
1227 out.push('|');
1228 }
1229 Value::BigInt(n) => {
1230 out.push('B');
1231 out.push_str(&n.to_string());
1232 out.push('|');
1233 }
1234 Value::Float(x) => {
1235 out.push('F');
1236 out.push_str(&x.to_string());
1237 out.push('|');
1238 }
1239 Value::Bool(b) => {
1240 out.push(if *b { 'T' } else { 'f' });
1241 out.push('|');
1242 }
1243 Value::Text(s) => {
1244 out.push('S');
1245 out.push_str(s);
1246 out.push('|');
1247 }
1248 Value::Vector(v) => {
1249 out.push('V');
1250 for x in v {
1251 out.push_str(&x.to_string());
1252 out.push(',');
1253 }
1254 out.push('|');
1255 }
1256 Value::Sq8Vector(q) => {
1262 out.push('Q');
1263 out.push_str(&q.min.to_string());
1264 out.push('@');
1265 out.push_str(&q.max.to_string());
1266 out.push(':');
1267 for b in &q.bytes {
1268 out.push_str(&b.to_string());
1269 out.push(',');
1270 }
1271 out.push('|');
1272 }
1273 Value::HalfVector(h) => {
1277 out.push('H');
1278 for b in &h.bytes {
1279 out.push_str(&b.to_string());
1280 out.push(',');
1281 }
1282 out.push('|');
1283 }
1284 Value::Numeric { scaled, scale } => {
1285 out.push('D');
1286 out.push_str(&scaled.to_string());
1287 out.push('@');
1288 out.push_str(&scale.to_string());
1289 out.push('|');
1290 }
1291 Value::Date(d) => {
1292 out.push('d');
1293 out.push_str(&d.to_string());
1294 out.push('|');
1295 }
1296 Value::Timestamp(t) => {
1297 out.push('t');
1298 out.push_str(&t.to_string());
1299 out.push('|');
1300 }
1301 Value::Interval { months, micros } => {
1302 out.push('i');
1303 out.push_str(&months.to_string());
1304 out.push('m');
1305 out.push_str(µs.to_string());
1306 out.push('|');
1307 }
1308 Value::Json(s) => {
1309 out.push('j');
1310 out.push_str(s);
1311 out.push('|');
1312 }
1313 _ => {
1318 out.push('?');
1319 out.push_str(&format!("{v:?}"));
1320 out.push('|');
1321 }
1322 }
1323 }
1324 out
1325}
1326
1327#[allow(clippy::cast_precision_loss)]
1328fn value_cmp(a: &Value, b: &Value) -> core::cmp::Ordering {
1329 use core::cmp::Ordering::Equal;
1330 match (a, b) {
1331 (Value::Null, Value::Null) => Equal,
1332 (Value::Null, _) => core::cmp::Ordering::Greater, (_, Value::Null) => core::cmp::Ordering::Less,
1334 (Value::Int(x), Value::Int(y)) => x.cmp(y),
1335 (Value::BigInt(x), Value::BigInt(y)) => x.cmp(y),
1336 (Value::Int(x), Value::BigInt(y)) => i64::from(*x).cmp(y),
1337 (Value::BigInt(x), Value::Int(y)) => x.cmp(&i64::from(*y)),
1338 (Value::Float(x), Value::Float(y)) => x.partial_cmp(y).unwrap_or(Equal),
1339 (Value::Int(x), Value::Float(y)) => f64::from(*x).partial_cmp(y).unwrap_or(Equal),
1340 (Value::Float(x), Value::Int(y)) => x.partial_cmp(&f64::from(*y)).unwrap_or(Equal),
1341 (Value::BigInt(x), Value::Float(y)) => (*x as f64).partial_cmp(y).unwrap_or(Equal),
1342 (Value::Float(x), Value::BigInt(y)) => x.partial_cmp(&(*y as f64)).unwrap_or(Equal),
1343 (Value::Text(x), Value::Text(y)) => x.cmp(y),
1344 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
1345 _ => Equal,
1346 }
1347}