1use alloc::boxed::Box;
24use alloc::collections::BTreeSet;
25use alloc::format;
26use alloc::string::{String, ToString};
27use alloc::vec::Vec;
28
29use spg_sql::ast::{Expr, SelectItem, SelectStatement};
30use spg_storage::{ColumnSchema, DataType, Row, Value};
31
32use crate::eval::{self, EvalContext, EvalError};
33
34pub fn uses_aggregate(stmt: &SelectStatement) -> bool {
36 if stmt.group_by.is_some() || stmt.having.is_some() {
37 return true;
38 }
39 for item in &stmt.items {
40 if let SelectItem::Expr { expr, .. } = item
41 && contains_aggregate(expr)
42 {
43 return true;
44 }
45 }
46 for o in &stmt.order_by {
47 if contains_aggregate(&o.expr) {
48 return true;
49 }
50 }
51 if let Some(h) = &stmt.having
52 && contains_aggregate(h)
53 {
54 return true;
55 }
56 false
57}
58
59pub fn contains_aggregate(e: &Expr) -> bool {
60 match e {
61 Expr::FunctionCall { name, args } => {
62 is_aggregate_name(name) || args.iter().any(contains_aggregate)
63 }
64 Expr::AggregateOrdered { .. } => true,
65 Expr::Binary { lhs, rhs, .. } => contains_aggregate(lhs) || contains_aggregate(rhs),
66 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
67 contains_aggregate(expr)
68 }
69 Expr::Like { expr, pattern, .. } => contains_aggregate(expr) || contains_aggregate(pattern),
70 Expr::Extract { source, .. } => contains_aggregate(source),
71 Expr::ScalarSubquery(_)
76 | Expr::Exists { .. }
77 | Expr::InSubquery { .. }
78 | Expr::WindowFunction { .. }
79 | Expr::Literal(_)
80 | Expr::Placeholder(_)
81 | Expr::Column(_) => false,
82 Expr::Array(items) => items.iter().any(contains_aggregate),
86 Expr::ArraySubscript { target, index } => {
87 contains_aggregate(target) || contains_aggregate(index)
88 }
89 Expr::AnyAll { expr, array, .. } => contains_aggregate(expr) || contains_aggregate(array),
90 Expr::Case {
93 operand,
94 branches,
95 else_branch,
96 } => {
97 operand.as_deref().is_some_and(contains_aggregate)
98 || branches
99 .iter()
100 .any(|(w, t)| contains_aggregate(w) || contains_aggregate(t))
101 || else_branch.as_deref().is_some_and(contains_aggregate)
102 }
103 }
104}
105
/// Returns `true` when `name` (compared ASCII-case-insensitively) is one of the aggregate
/// functions this module knows how to evaluate.
pub fn is_aggregate_name(name: &str) -> bool {
    const KNOWN_AGGREGATES: [&str; 11] = [
        "count",
        "count_star",
        "sum",
        "min",
        "max",
        "avg",
        "string_agg",
        "array_agg",
        "bool_and",
        "bool_or",
        "every",
    ];
    KNOWN_AGGREGATES
        .iter()
        .any(|known| known.eq_ignore_ascii_case(name))
}
128
/// Running accumulator for one aggregate call within one group.
///
/// One struct serves every supported aggregate; each aggregate only reads and writes the
/// fields it needs (see `update_state` and `finalize`).
#[derive(Debug, Default, Clone)]
struct AggState {
    /// Inputs counted: every row for `count_star`; non-NULL inputs for `count`, `sum`,
    /// `avg` and `string_agg`; every input (NULLs included) for `array_agg`.
    count: i64,
    /// Integer part of the running `sum` / `avg` total.
    sum_int: i64,
    /// Floating-point part of the running `sum` / `avg` total.
    sum_float: f64,
    /// Current `min` / `max` candidate; `None` until a non-NULL input arrives.
    extreme: Option<Value>,
    /// Set once a `Float` input reaches `sum` / `avg`; the result is then a float.
    use_float: bool,
    /// Values collected by `string_agg` (non-NULL text) and `array_agg` (all inputs).
    items: Vec<Value>,
    /// `encode_key` encodings of argument values already folded in; consulted by `run`
    /// only for `DISTINCT` aggregates.
    seen: BTreeSet<String>,
    /// Per-item ORDER BY key tuples, pushed in step with `items` when the aggregate call
    /// carries an ORDER BY.
    item_keys: Vec<Vec<Value>>,
    /// `string_agg` separator: the most recent text separator argument seen.
    separator: Option<String>,
    /// `bool_and` / `bool_or` accumulator; `None` until a non-NULL input arrives.
    bool_acc: Option<bool>,
}
163
/// One distinct aggregate call found in the SELECT list, ORDER BY or HAVING.
///
/// Its position in the spec list `i` determines the synthetic column `__agg_<i>` that
/// `rewrite_expr` substitutes for every occurrence of the call.
#[derive(Debug, Clone)]
struct AggSpec {
    /// Lowercased aggregate name; `every` is stored as `bool_and`.
    name: String,
    /// First argument expression; `None` for `count(*)` (`count_star`).
    arg: Option<Expr>,
    /// Second argument; only populated for `string_agg` (the separator).
    arg2: Option<Expr>,
    /// Whether the call carried `DISTINCT`.
    distinct: bool,
    /// The aggregate's own ORDER BY (e.g. `string_agg(x, ',' ORDER BY y)`); empty if none.
    order_by: Vec<spg_sql::ast::OrderBy>,
}
186
/// Output of `run`: the projected result schema and the result rows.
#[derive(Debug)]
pub struct AggResult {
    /// One column per SELECT-list expression, named by its alias or its SQL text.
    pub columns: Vec<ColumnSchema>,
    /// One row per group that passed HAVING, ordered by ORDER BY when present.
    pub rows: Vec<Row>,
}
194
195#[allow(clippy::too_many_lines)]
198pub type CorrelatedEval<'a> = &'a dyn Fn(&Expr, &Row, &EvalContext<'_>) -> Result<Value, EvalError>;
205
206pub fn run(
207 stmt: &SelectStatement,
208 rows: &[&Row],
209 schema_cols: &[ColumnSchema],
210 table_alias: Option<&str>,
211 correlated_eval: Option<CorrelatedEval<'_>>,
212) -> Result<AggResult, EvalError> {
213 let ctx = EvalContext::new(schema_cols, table_alias);
214 let group_exprs: Vec<Expr> = stmt.group_by.clone().unwrap_or_default();
215
216 let mut agg_specs: Vec<AggSpec> = Vec::new();
218 for item in &stmt.items {
219 if let SelectItem::Expr { expr, .. } = item {
220 collect_aggregates(expr, &mut agg_specs);
221 }
222 }
223 for o in &stmt.order_by {
224 collect_aggregates(&o.expr, &mut agg_specs);
225 }
226 if let Some(h) = &stmt.having {
227 collect_aggregates(h, &mut agg_specs);
228 }
229 validate_agg_arities(stmt, &agg_specs)?;
235
236 let mut groups: hashbrown::HashMap<String, (Vec<Value>, Vec<AggState>)> =
240 hashbrown::HashMap::new();
241 let mut key_order: Vec<String> = Vec::new();
242 if rows.is_empty() && group_exprs.is_empty() {
245 let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
247 groups.insert(String::new(), (Vec::new(), init));
248 key_order.push(String::new());
249 }
250
251 for row in rows {
252 let group_vals: Vec<Value> = group_exprs
253 .iter()
254 .map(|g| eval::eval_expr(g, row, &ctx))
255 .collect::<Result<_, _>>()?;
256 let mut key_vals = group_vals.clone();
262 for (i, g) in group_exprs.iter().enumerate() {
263 if matches!(
264 eval::column_collation(g, &ctx),
265 Some(spg_storage::Collation::CaseInsensitive)
266 ) {
267 if let Value::Text(s) = &key_vals[i] {
268 key_vals[i] = Value::Text(s.to_ascii_lowercase());
269 }
270 }
271 }
272 let key = encode_key(&key_vals);
273 let entry = groups.entry(key.clone()).or_insert_with(|| {
274 key_order.push(key.clone());
275 let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
276 (group_vals.clone(), init)
277 });
278 for (i, spec) in agg_specs.iter().enumerate() {
279 let arg_val = match &spec.arg {
280 None => Value::Bool(true), Some(e) => eval::eval_expr(e, row, &ctx)?,
282 };
283 let arg2_val = match &spec.arg2 {
289 None => None,
290 Some(e) => Some(eval::eval_expr(e, row, &ctx)?),
291 };
292 let order_keys = if spec.order_by.is_empty() {
295 None
296 } else {
297 let mut keys = Vec::with_capacity(spec.order_by.len());
298 for o in &spec.order_by {
299 keys.push(eval::eval_expr(&o.expr, row, &ctx)?);
300 }
301 Some(keys)
302 };
303 if spec.distinct {
308 let key = encode_key(core::slice::from_ref(&arg_val));
309 if !entry.1[i].seen.insert(key) {
310 continue;
311 }
312 }
313 update_state(
314 &mut entry.1[i],
315 &spec.name,
316 &arg_val,
317 arg2_val.as_ref(),
318 order_keys,
319 )?;
320 }
321 }
322
323 let group_types: Vec<DataType> = if rows.is_empty() {
325 group_exprs.iter().map(|_| DataType::Text).collect()
328 } else {
329 let probe = rows[0];
330 group_exprs
331 .iter()
332 .map(|g| {
333 eval::eval_expr(g, probe, &ctx).map(|v| v.data_type().unwrap_or(DataType::Text))
334 })
335 .collect::<Result<_, _>>()?
336 };
337 let agg_types: Vec<DataType> = agg_specs
338 .iter()
339 .map(|spec| infer_agg_type(spec, schema_cols))
340 .collect();
341 let mut synth_schema: Vec<ColumnSchema> = Vec::new();
342 for (i, ty) in group_types.iter().enumerate() {
343 synth_schema.push(ColumnSchema::new(format!("__grp_{i}"), *ty, true));
344 }
345 for (i, ty) in agg_types.iter().enumerate() {
346 synth_schema.push(ColumnSchema::new(format!("__agg_{i}"), *ty, true));
347 }
348
349 let mut synth_rows: Vec<Row> = Vec::new();
351 for k in &key_order {
352 let (gvals, states) = &groups[k];
353 let mut values: Vec<Value> = Vec::with_capacity(synth_schema.len());
354 values.extend(gvals.iter().cloned());
355 for (i, st) in states.iter().enumerate() {
356 let st_sorted;
360 let st_final: &AggState =
361 if !agg_specs[i].order_by.is_empty() && st.item_keys.len() == st.items.len() {
362 let mut idx: Vec<usize> = (0..st.items.len()).collect();
363 let ob = &agg_specs[i].order_by;
364 idx.sort_by(|&x, &y| {
365 for (k, o) in ob.iter().enumerate() {
366 let cmp = crate::order_by_value_cmp(
367 o.desc,
368 o.nulls_first,
369 &st.item_keys[x][k],
370 &st.item_keys[y][k],
371 );
372 if cmp != core::cmp::Ordering::Equal {
373 return cmp;
374 }
375 }
376 core::cmp::Ordering::Equal
377 });
378 let mut sorted = st.clone();
379 sorted.items = idx.iter().map(|&j| st.items[j].clone()).collect();
380 st_sorted = sorted;
381 &st_sorted
382 } else {
383 st
384 };
385 values.push(finalize(&agg_specs[i].name, st_final));
386 }
387 synth_rows.push(Row::new(values));
388 }
389
390 let columns: Vec<ColumnSchema> = stmt
395 .items
396 .iter()
397 .map(|item| match item {
398 SelectItem::Wildcard => Err(EvalError::TypeMismatch {
399 detail: "SELECT * with aggregates is not supported".into(),
400 }),
401 SelectItem::Expr { expr, alias } => {
402 let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
403 let name = alias.clone().unwrap_or_else(|| expr.to_string());
404 Ok(ColumnSchema::new(
405 name,
406 agg_or_group_type(&rewritten, &synth_schema),
407 true,
408 ))
409 }
410 })
411 .collect::<Result<_, _>>()?;
412
413 let synth_ctx = EvalContext::new(&synth_schema, None);
418 let having_rewritten = stmt
419 .having
420 .as_ref()
421 .map(|h| rewrite_expr(h, &group_exprs, &agg_specs));
422 let mut kept_synth: Vec<Row> = Vec::new();
423 let mut out_rows: Vec<Row> = Vec::new();
424 for srow in synth_rows {
425 if let Some(h) = &having_rewritten {
426 let cond = match correlated_eval {
427 Some(f) if crate::expr_has_subquery(h) => f(h, &srow, &synth_ctx)?,
428 _ => eval::eval_expr(h, &srow, &synth_ctx)?,
429 };
430 if !matches!(cond, Value::Bool(true)) {
431 continue;
432 }
433 }
434 let mut values: Vec<Value> = Vec::with_capacity(columns.len());
435 for item in &stmt.items {
436 if let SelectItem::Expr { expr, .. } = item {
437 let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
438 values.push(match correlated_eval {
439 Some(f) if crate::expr_has_subquery(&rewritten) => {
440 f(&rewritten, &srow, &synth_ctx)?
441 }
442 _ => eval::eval_expr(&rewritten, &srow, &synth_ctx)?,
443 });
444 }
445 }
446 kept_synth.push(srow);
447 out_rows.push(Row::new(values));
448 }
449
450 if !stmt.order_by.is_empty() {
453 let rewritten: Vec<Expr> = stmt
456 .order_by
457 .iter()
458 .map(|o| rewrite_expr(&o.expr, &group_exprs, &agg_specs))
459 .collect();
460 let keys_meta: Vec<(bool, Option<bool>)> = stmt
461 .order_by
462 .iter()
463 .map(|o| (o.desc, o.nulls_first))
464 .collect();
465 let mut tagged: Vec<(Vec<Value>, Row)> = kept_synth
466 .into_iter()
467 .zip(out_rows)
468 .map(|(s, o)| {
469 let mut keys = Vec::with_capacity(rewritten.len());
470 for e in &rewritten {
471 keys.push(match correlated_eval {
472 Some(f) if crate::expr_has_subquery(e) => f(e, &s, &synth_ctx)?,
473 _ => eval::eval_expr(e, &s, &synth_ctx)?,
474 });
475 }
476 Ok::<_, EvalError>((keys, o))
477 })
478 .collect::<Result<_, _>>()?;
479 tagged.sort_by(|a, b| {
480 use core::cmp::Ordering;
481 for (i, (ka, kb)) in a.0.iter().zip(b.0.iter()).enumerate() {
482 let (desc, nf) = keys_meta[i];
483 let cmp = crate::order_by_value_cmp(desc, nf, ka, kb);
484 if cmp != Ordering::Equal {
485 return cmp;
486 }
487 }
488 Ordering::Equal
489 });
490 out_rows = tagged.into_iter().map(|(_, o)| o).collect();
491 }
492
493 Ok(AggResult {
494 columns,
495 rows: out_rows,
496 })
497}
498
499fn validate_agg_arities(stmt: &SelectStatement, _specs: &[AggSpec]) -> Result<(), EvalError> {
505 fn walk(e: &Expr) -> Result<(), EvalError> {
506 if let Expr::FunctionCall { name, args } = e {
507 let lower = name.to_ascii_lowercase();
508 let expected: Option<usize> = match lower.as_str() {
509 "count_star" => Some(0),
510 "count" | "sum" | "avg" | "min" | "max" | "array_agg"
511 | "bool_and" | "bool_or" | "every" => Some(1),
515 "string_agg" => Some(2),
516 _ => None,
517 };
518 if let Some(want) = expected
519 && args.len() != want
520 {
521 return Err(EvalError::TypeMismatch {
522 detail: alloc::format!("{lower}() takes {want} arg(s), got {}", args.len()),
523 });
524 }
525 for a in args {
526 walk(a)?;
527 }
528 } else if let Expr::Binary { lhs, rhs, .. } = e {
529 walk(lhs)?;
530 walk(rhs)?;
531 } else if let Expr::Unary { expr, .. }
532 | Expr::Cast { expr, .. }
533 | Expr::IsNull { expr, .. } = e
534 {
535 walk(expr)?;
536 }
537 Ok(())
538 }
539 for item in &stmt.items {
540 if let SelectItem::Expr { expr, .. } = item {
541 walk(expr)?;
542 }
543 }
544 for o in &stmt.order_by {
545 walk(&o.expr)?;
546 }
547 if let Some(h) = &stmt.having {
548 walk(h)?;
549 }
550 Ok(())
551}
552
553fn collect_aggregates(e: &Expr, out: &mut Vec<AggSpec>) {
554 match e {
555 Expr::AggregateOrdered {
558 call,
559 order_by,
560 distinct,
561 } => {
562 if let Expr::FunctionCall { name, args } = call.as_ref() {
563 let lower = name.to_ascii_lowercase();
564 if is_aggregate_name(&lower) {
565 let canonical = if lower == "every" {
566 "bool_and".to_string()
567 } else {
568 lower
569 };
570 let spec = AggSpec {
571 name: canonical,
572 arg: args.first().cloned(),
573 arg2: if name.eq_ignore_ascii_case("string_agg") {
574 args.get(1).cloned()
575 } else {
576 None
577 },
578 distinct: *distinct,
579 order_by: order_by.clone(),
580 };
581 if !out.iter().any(|s| {
582 s.name == spec.name
583 && s.arg == spec.arg
584 && s.arg2 == spec.arg2
585 && s.distinct == spec.distinct
586 && s.order_by == spec.order_by
587 }) {
588 out.push(spec);
589 }
590 return;
591 }
592 }
593 collect_aggregates(call, out);
594 for o in order_by {
595 collect_aggregates(&o.expr, out);
596 }
597 }
598 Expr::FunctionCall { name, args } => {
599 let lower = name.to_ascii_lowercase();
600 if is_aggregate_name(&lower) {
601 let arg = if lower == "count_star" {
602 None
603 } else {
604 args.first().cloned()
605 };
606 let arg2 = if lower == "string_agg" {
610 args.get(1).cloned()
611 } else {
612 None
613 };
614 let canonical = if lower == "every" {
618 "bool_and".to_string()
619 } else {
620 lower
621 };
622 let spec = AggSpec {
623 name: canonical,
624 arg: arg.clone(),
625 arg2: arg2.clone(),
626 distinct: false,
627 order_by: Vec::new(),
628 };
629 if !out.iter().any(|s| {
630 s.name == spec.name
631 && s.arg == spec.arg
632 && s.arg2 == spec.arg2
633 && !s.distinct
634 && s.order_by == spec.order_by
635 }) {
636 out.push(spec);
637 }
638 } else {
641 for a in args {
642 collect_aggregates(a, out);
643 }
644 }
645 }
646 Expr::Binary { lhs, rhs, .. } => {
647 collect_aggregates(lhs, out);
648 collect_aggregates(rhs, out);
649 }
650 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
651 collect_aggregates(expr, out);
652 }
653 Expr::Like { expr, pattern, .. } => {
654 collect_aggregates(expr, out);
655 collect_aggregates(pattern, out);
656 }
657 Expr::Extract { source, .. } => collect_aggregates(source, out),
658 Expr::ScalarSubquery(_)
661 | Expr::Exists { .. }
662 | Expr::InSubquery { .. }
663 | Expr::WindowFunction { .. }
664 | Expr::Literal(_)
665 | Expr::Placeholder(_)
666 | Expr::Column(_) => {}
667 Expr::Array(items) => {
670 for elem in items {
671 collect_aggregates(elem, out);
672 }
673 }
674 Expr::ArraySubscript { target, index } => {
675 collect_aggregates(target, out);
676 collect_aggregates(index, out);
677 }
678 Expr::AnyAll { expr, array, .. } => {
679 collect_aggregates(expr, out);
680 collect_aggregates(array, out);
681 }
682 Expr::Case {
683 operand,
684 branches,
685 else_branch,
686 } => {
687 if let Some(o) = operand {
688 collect_aggregates(o, out);
689 }
690 for (w, t) in branches {
691 collect_aggregates(w, out);
692 collect_aggregates(t, out);
693 }
694 if let Some(e) = else_branch {
695 collect_aggregates(e, out);
696 }
697 }
698 }
699}
700
701fn update_state(
702 st: &mut AggState,
703 name: &str,
704 v: &Value,
705 arg2: Option<&Value>,
706 order_keys: Option<Vec<Value>>,
707) -> Result<(), EvalError> {
708 let is_null = matches!(v, Value::Null);
709 match name {
710 "count_star" => st.count += 1,
711 "count" => {
712 if !is_null {
713 st.count += 1;
714 }
715 }
716 "sum" | "avg" => {
717 if is_null {
718 return Ok(());
719 }
720 st.count += 1;
721 match v {
722 Value::Int(n) => st.sum_int += i64::from(*n),
723 Value::BigInt(n) => st.sum_int += *n,
724 Value::Float(x) => {
725 st.use_float = true;
726 st.sum_float += *x;
727 }
728 other => {
729 return Err(EvalError::TypeMismatch {
730 detail: format!("sum/avg need numeric, got {:?}", other.data_type()),
731 });
732 }
733 }
734 }
735 "min" => {
736 if is_null {
737 return Ok(());
738 }
739 match &st.extreme {
740 None => st.extreme = Some(v.clone()),
741 Some(cur) => {
742 if value_cmp(v, cur) == core::cmp::Ordering::Less {
743 st.extreme = Some(v.clone());
744 }
745 }
746 }
747 }
748 "max" => {
749 if is_null {
750 return Ok(());
751 }
752 match &st.extreme {
753 None => st.extreme = Some(v.clone()),
754 Some(cur) => {
755 if value_cmp(v, cur) == core::cmp::Ordering::Greater {
756 st.extreme = Some(v.clone());
757 }
758 }
759 }
760 }
761 "string_agg" => {
769 if let Some(sep) = arg2
770 && let Value::Text(s) = sep
771 {
772 st.separator = Some(s.clone());
773 }
774 if is_null {
775 return Ok(());
776 }
777 if let Value::Text(s) = v {
778 st.items.push(Value::Text(s.clone()));
779 if let Some(k) = order_keys {
780 st.item_keys.push(k);
781 }
782 st.count += 1;
783 } else {
784 return Err(EvalError::TypeMismatch {
785 detail: format!("string_agg requires text value, got {:?}", v.data_type()),
786 });
787 }
788 }
789 "array_agg" => {
795 st.items.push(v.clone());
796 if let Some(k) = order_keys {
797 st.item_keys.push(k);
798 }
799 st.count += 1;
800 }
801 "bool_and" => {
805 if is_null {
806 return Ok(());
807 }
808 let b = match v {
809 Value::Bool(b) => *b,
810 other => {
811 return Err(EvalError::TypeMismatch {
812 detail: format!("bool_and requires bool, got {:?}", other.data_type()),
813 });
814 }
815 };
816 st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc && b));
817 }
818 "bool_or" => {
821 if is_null {
822 return Ok(());
823 }
824 let b = match v {
825 Value::Bool(b) => *b,
826 other => {
827 return Err(EvalError::TypeMismatch {
828 detail: format!("bool_or requires bool, got {:?}", other.data_type()),
829 });
830 }
831 };
832 st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc || b));
833 }
834 _ => unreachable!("non-aggregate {name} in update_state"),
835 }
836 Ok(())
837}
838
839#[allow(clippy::cast_precision_loss)]
840fn finalize(name: &str, st: &AggState) -> Value {
841 match name {
842 "count" | "count_star" => Value::BigInt(st.count),
843 "sum" => {
844 if st.count == 0 {
845 Value::Null
846 } else if st.use_float {
847 Value::Float(st.sum_float + (st.sum_int as f64))
848 } else {
849 Value::BigInt(st.sum_int)
850 }
851 }
852 "avg" => {
853 if st.count == 0 {
854 Value::Null
855 } else {
856 let total = if st.use_float {
857 st.sum_float + (st.sum_int as f64)
858 } else {
859 st.sum_int as f64
860 };
861 Value::Float(total / (st.count as f64))
862 }
863 }
864 "min" | "max" => st.extreme.clone().unwrap_or(Value::Null),
865 "string_agg" => {
869 if st.items.is_empty() {
870 return Value::Null;
871 }
872 let sep = st.separator.clone().unwrap_or_default();
873 let mut out = String::new();
874 for (i, item) in st.items.iter().enumerate() {
875 if i > 0 {
876 out.push_str(&sep);
877 }
878 if let Value::Text(s) = item {
879 out.push_str(s);
880 }
881 }
882 Value::Text(out)
883 }
884 "array_agg" => {
891 if st.items.is_empty() {
892 return Value::Null;
893 }
894 let probe = st.items.iter().find(|v| !v.is_null());
895 match probe.and_then(spg_storage::Value::data_type) {
896 Some(DataType::Int) | Some(DataType::SmallInt) => {
897 let items: Vec<Option<i32>> = st
898 .items
899 .iter()
900 .map(|v| match v {
901 Value::Int(n) => Some(*n),
902 Value::SmallInt(n) => Some(i32::from(*n)),
903 _ => None,
904 })
905 .collect();
906 Value::IntArray(items)
907 }
908 Some(DataType::BigInt) => {
909 let items: Vec<Option<i64>> = st
910 .items
911 .iter()
912 .map(|v| match v {
913 Value::BigInt(n) => Some(*n),
914 _ => None,
915 })
916 .collect();
917 Value::BigIntArray(items)
918 }
919 _ => {
920 let items: Vec<Option<String>> = st
921 .items
922 .iter()
923 .map(|v| match v {
924 Value::Text(s) => Some(s.clone()),
925 Value::Null => None,
926 other => Some(format!("{other:?}")),
927 })
928 .collect();
929 Value::TextArray(items)
930 }
931 }
932 }
933 "bool_and" | "bool_or" => st.bool_acc.map_or(Value::Null, Value::Bool),
937 _ => unreachable!(),
938 }
939}
940
941fn infer_agg_type(spec: &AggSpec, schema_cols: &[ColumnSchema]) -> DataType {
942 let arg_ty = spec
946 .arg
947 .as_ref()
948 .and_then(|a| crate::describe::describe_expr(a, schema_cols))
949 .map(|shape| shape.ty);
950 match spec.name.as_str() {
951 "count" | "count_star" => DataType::BigInt,
952 "sum" => match arg_ty {
953 Some(DataType::Float) => DataType::Float,
954 _ => DataType::BigInt,
955 },
956 "avg" => DataType::Float,
957 "string_agg" => DataType::Text,
959 "array_agg" => match arg_ty {
960 Some(DataType::Int | DataType::SmallInt) => DataType::IntArray,
961 Some(DataType::BigInt) => DataType::BigIntArray,
962 _ => DataType::TextArray,
963 },
964 "bool_and" | "bool_or" => DataType::Bool,
967 _ => arg_ty.unwrap_or(DataType::Text),
969 }
970}
971
972fn agg_or_group_type(e: &Expr, synth: &[ColumnSchema]) -> DataType {
973 if let Expr::Column(c) = e
974 && let Some(s) = synth.iter().find(|s| s.name == c.name)
975 {
976 return s.ty;
977 }
978 crate::describe::describe_expr(e, synth)
984 .map(|shape| shape.ty)
985 .unwrap_or(DataType::Text)
986}
987
/// Rewrites an expression from the SELECT list, HAVING or ORDER BY so it can be evaluated
/// against a synthetic group row.
///
/// * An aggregate call matching an entry of `aggs` becomes column `__agg_<i>`. A match
///   needs the same canonical name, first argument, `string_agg` separator, DISTINCT flag
///   and ORDER BY. The argument derivation below must stay in sync with
///   `collect_aggregates`.
/// * A sub-expression structurally equal to the i-th GROUP BY expression becomes column
///   `__grp_<i>`.
/// * Otherwise the node is rebuilt with its children rewritten. Inside subqueries only
///   GROUP BY expressions are substituted (via `rewrite_group_keys_in_select`, which
///   passes no aggregate specs).
///
/// Leaves that match nothing are cloned unchanged.
fn rewrite_expr(e: &Expr, group_exprs: &[Expr], aggs: &[AggSpec]) -> Expr {
    // Aggregate carrying DISTINCT and/or ORDER BY. Unlike the plain-call case below, the
    // argument is always `args.first()` (no count_star special case), as in
    // `collect_aggregates`.
    if let Expr::AggregateOrdered {
        call,
        order_by,
        distinct,
    } = e
        && let Expr::FunctionCall { name, args } = call.as_ref()
    {
        let lower = name.to_ascii_lowercase();
        if is_aggregate_name(&lower) {
            let canonical: &str = if lower == "every" { "bool_and" } else { &lower };
            let arg = args.first().cloned();
            let arg2 = if lower == "string_agg" {
                args.get(1).cloned()
            } else {
                None
            };
            for (i, spec) in aggs.iter().enumerate() {
                if spec.name == canonical
                    && spec.arg == arg
                    && spec.arg2 == arg2
                    && spec.distinct == *distinct
                    && spec.order_by == *order_by
                {
                    return Expr::Column(spg_sql::ast::ColumnName {
                        qualifier: None,
                        name: format!("__agg_{i}"),
                    });
                }
            }
        }
    }
    // Plain aggregate call: matches only specs without DISTINCT and without ORDER BY.
    if let Expr::FunctionCall { name, args } = e {
        let lower = name.to_ascii_lowercase();
        if is_aggregate_name(&lower) {
            let arg = if lower == "count_star" {
                None
            } else {
                args.first().cloned()
            };
            let arg2 = if lower == "string_agg" {
                args.get(1).cloned()
            } else {
                None
            };
            let canonical: &str = if lower == "every" {
                "bool_and"
            } else {
                lower.as_str()
            };
            for (i, spec) in aggs.iter().enumerate() {
                if spec.name == canonical
                    && spec.arg == arg
                    && spec.arg2 == arg2
                    && !spec.distinct
                    && spec.order_by.is_empty()
                {
                    return Expr::Column(spg_sql::ast::ColumnName {
                        qualifier: None,
                        name: format!("__agg_{i}"),
                    });
                }
            }
        }
    }
    // A whole sub-expression equal to a GROUP BY expression reads the group's key column.
    for (i, g) in group_exprs.iter().enumerate() {
        if g == e {
            return Expr::Column(spg_sql::ast::ColumnName {
                qualifier: None,
                name: format!("__grp_{i}"),
            });
        }
    }
    // No direct match: rebuild the node with rewritten children.
    match e {
        Expr::AggregateOrdered {
            call,
            order_by,
            distinct,
        } => Expr::AggregateOrdered {
            call: Box::new(rewrite_expr(call, group_exprs, aggs)),
            distinct: *distinct,
            order_by: order_by
                .iter()
                .map(|o| spg_sql::ast::OrderBy {
                    expr: rewrite_expr(&o.expr, group_exprs, aggs),
                    desc: o.desc,
                    nulls_first: o.nulls_first,
                })
                .collect(),
        },
        Expr::Binary { lhs, op, rhs } => Expr::Binary {
            lhs: Box::new(rewrite_expr(lhs, group_exprs, aggs)),
            op: *op,
            rhs: Box::new(rewrite_expr(rhs, group_exprs, aggs)),
        },
        Expr::Unary { op, expr } => Expr::Unary {
            op: *op,
            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
        },
        Expr::Cast { expr, target } => Expr::Cast {
            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
            target: *target,
        },
        Expr::IsNull { expr, negated } => Expr::IsNull {
            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
            negated: *negated,
        },
        Expr::FunctionCall { name, args } => Expr::FunctionCall {
            name: name.clone(),
            args: args
                .iter()
                .map(|a| rewrite_expr(a, group_exprs, aggs))
                .collect(),
        },
        Expr::Like {
            expr,
            pattern,
            negated,
            case_insensitive,
        } => Expr::Like {
            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
            pattern: Box::new(rewrite_expr(pattern, group_exprs, aggs)),
            negated: *negated,
            case_insensitive: *case_insensitive,
        },
        Expr::Extract { field, source } => Expr::Extract {
            field: *field,
            source: Box::new(rewrite_expr(source, group_exprs, aggs)),
        },
        // Subqueries: substitute only GROUP BY keys inside the subquery body.
        // NOTE(review): presumably so a correlated evaluator can resolve them against the
        // synthetic row — confirm.
        Expr::ScalarSubquery(s) => {
            Expr::ScalarSubquery(Box::new(rewrite_group_keys_in_select(s, group_exprs)))
        }
        Expr::Exists { subquery, negated } => Expr::Exists {
            subquery: Box::new(rewrite_group_keys_in_select(subquery, group_exprs)),
            negated: *negated,
        },
        Expr::InSubquery {
            expr,
            subquery,
            negated,
        } => Expr::InSubquery {
            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
            subquery: Box::new(rewrite_group_keys_in_select(subquery, group_exprs)),
            negated: *negated,
        },
        Expr::WindowFunction { .. } | Expr::Literal(_) | Expr::Placeholder(_) | Expr::Column(_) => {
            e.clone()
        }
        Expr::Array(items) => Expr::Array(
            items
                .iter()
                .map(|elem| rewrite_expr(elem, group_exprs, aggs))
                .collect(),
        ),
        Expr::ArraySubscript { target, index } => Expr::ArraySubscript {
            target: Box::new(rewrite_expr(target, group_exprs, aggs)),
            index: Box::new(rewrite_expr(index, group_exprs, aggs)),
        },
        Expr::AnyAll {
            expr,
            op,
            array,
            is_any,
        } => Expr::AnyAll {
            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
            op: *op,
            array: Box::new(rewrite_expr(array, group_exprs, aggs)),
            is_any: *is_any,
        },
        Expr::Case {
            operand,
            branches,
            else_branch,
        } => Expr::Case {
            operand: operand
                .as_deref()
                .map(|o| Box::new(rewrite_expr(o, group_exprs, aggs))),
            branches: branches
                .iter()
                .map(|(w, t)| {
                    (
                        rewrite_expr(w, group_exprs, aggs),
                        rewrite_expr(t, group_exprs, aggs),
                    )
                })
                .collect(),
            else_branch: else_branch
                .as_deref()
                .map(|e| Box::new(rewrite_expr(e, group_exprs, aggs))),
        },
    }
}
1198
1199fn rewrite_group_keys_in_select(
1204 s: &spg_sql::ast::SelectStatement,
1205 group_exprs: &[Expr],
1206) -> spg_sql::ast::SelectStatement {
1207 let mut out = s.clone();
1208 let _ = crate::walk_select_exprs_mut(&mut out, &mut |e| {
1209 *e = rewrite_expr(e, group_exprs, &[]);
1210 Ok(())
1211 });
1212 out
1213}
1214
1215pub(crate) fn encode_key(vals: &[Value]) -> String {
1217 let mut out = String::new();
1218 for v in vals {
1219 match v {
1220 Value::Null => out.push_str("N|"),
1221 Value::SmallInt(n) => {
1222 out.push('s');
1223 out.push_str(&n.to_string());
1224 out.push('|');
1225 }
1226 Value::Int(n) => {
1227 out.push('I');
1228 out.push_str(&n.to_string());
1229 out.push('|');
1230 }
1231 Value::BigInt(n) => {
1232 out.push('B');
1233 out.push_str(&n.to_string());
1234 out.push('|');
1235 }
1236 Value::Float(x) => {
1237 out.push('F');
1238 out.push_str(&x.to_string());
1239 out.push('|');
1240 }
1241 Value::Bool(b) => {
1242 out.push(if *b { 'T' } else { 'f' });
1243 out.push('|');
1244 }
1245 Value::Text(s) => {
1246 out.push('S');
1247 out.push_str(s);
1248 out.push('|');
1249 }
1250 Value::Vector(v) => {
1251 out.push('V');
1252 for x in v {
1253 out.push_str(&x.to_string());
1254 out.push(',');
1255 }
1256 out.push('|');
1257 }
1258 Value::Sq8Vector(q) => {
1264 out.push('Q');
1265 out.push_str(&q.min.to_string());
1266 out.push('@');
1267 out.push_str(&q.max.to_string());
1268 out.push(':');
1269 for b in &q.bytes {
1270 out.push_str(&b.to_string());
1271 out.push(',');
1272 }
1273 out.push('|');
1274 }
1275 Value::HalfVector(h) => {
1279 out.push('H');
1280 for b in &h.bytes {
1281 out.push_str(&b.to_string());
1282 out.push(',');
1283 }
1284 out.push('|');
1285 }
1286 Value::Numeric { scaled, scale } => {
1287 out.push('D');
1288 out.push_str(&scaled.to_string());
1289 out.push('@');
1290 out.push_str(&scale.to_string());
1291 out.push('|');
1292 }
1293 Value::Date(d) => {
1294 out.push('d');
1295 out.push_str(&d.to_string());
1296 out.push('|');
1297 }
1298 Value::Timestamp(t) => {
1299 out.push('t');
1300 out.push_str(&t.to_string());
1301 out.push('|');
1302 }
1303 Value::Interval { months, micros } => {
1304 out.push('i');
1305 out.push_str(&months.to_string());
1306 out.push('m');
1307 out.push_str(µs.to_string());
1308 out.push('|');
1309 }
1310 Value::Json(s) => {
1311 out.push('j');
1312 out.push_str(s);
1313 out.push('|');
1314 }
1315 _ => {
1320 out.push('?');
1321 out.push_str(&format!("{v:?}"));
1322 out.push('|');
1323 }
1324 }
1325 }
1326 out
1327}
1328
1329#[allow(clippy::cast_precision_loss)]
1330fn value_cmp(a: &Value, b: &Value) -> core::cmp::Ordering {
1331 use core::cmp::Ordering::Equal;
1332 match (a, b) {
1333 (Value::Null, Value::Null) => Equal,
1334 (Value::Null, _) => core::cmp::Ordering::Greater, (_, Value::Null) => core::cmp::Ordering::Less,
1336 (Value::Int(x), Value::Int(y)) => x.cmp(y),
1337 (Value::BigInt(x), Value::BigInt(y)) => x.cmp(y),
1338 (Value::Int(x), Value::BigInt(y)) => i64::from(*x).cmp(y),
1339 (Value::BigInt(x), Value::Int(y)) => x.cmp(&i64::from(*y)),
1340 (Value::Float(x), Value::Float(y)) => x.partial_cmp(y).unwrap_or(Equal),
1341 (Value::Int(x), Value::Float(y)) => f64::from(*x).partial_cmp(y).unwrap_or(Equal),
1342 (Value::Float(x), Value::Int(y)) => x.partial_cmp(&f64::from(*y)).unwrap_or(Equal),
1343 (Value::BigInt(x), Value::Float(y)) => (*x as f64).partial_cmp(y).unwrap_or(Equal),
1344 (Value::Float(x), Value::BigInt(y)) => x.partial_cmp(&(*y as f64)).unwrap_or(Equal),
1345 (Value::Text(x), Value::Text(y)) => x.cmp(y),
1346 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
1347 _ => Equal,
1348 }
1349}