1use alloc::boxed::Box;
24use alloc::collections::BTreeMap;
25use alloc::format;
26use alloc::string::{String, ToString};
27use alloc::vec::Vec;
28
29use spg_sql::ast::{Expr, SelectItem, SelectStatement};
30use spg_storage::{ColumnSchema, DataType, Row, Value};
31
32use crate::eval::{self, EvalContext, EvalError};
33
34pub fn uses_aggregate(stmt: &SelectStatement) -> bool {
36 if stmt.group_by.is_some() || stmt.having.is_some() {
37 return true;
38 }
39 for item in &stmt.items {
40 if let SelectItem::Expr { expr, .. } = item
41 && contains_aggregate(expr)
42 {
43 return true;
44 }
45 }
46 for o in &stmt.order_by {
47 if contains_aggregate(&o.expr) {
48 return true;
49 }
50 }
51 if let Some(h) = &stmt.having
52 && contains_aggregate(h)
53 {
54 return true;
55 }
56 false
57}
58
59pub fn contains_aggregate(e: &Expr) -> bool {
60 match e {
61 Expr::FunctionCall { name, args } => {
62 is_aggregate_name(name) || args.iter().any(contains_aggregate)
63 }
64 Expr::AggregateOrdered { .. } => true,
65 Expr::Binary { lhs, rhs, .. } => contains_aggregate(lhs) || contains_aggregate(rhs),
66 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
67 contains_aggregate(expr)
68 }
69 Expr::Like { expr, pattern, .. } => contains_aggregate(expr) || contains_aggregate(pattern),
70 Expr::Extract { source, .. } => contains_aggregate(source),
71 Expr::ScalarSubquery(_)
76 | Expr::Exists { .. }
77 | Expr::InSubquery { .. }
78 | Expr::WindowFunction { .. }
79 | Expr::Literal(_)
80 | Expr::Placeholder(_)
81 | Expr::Column(_) => false,
82 Expr::Array(items) => items.iter().any(contains_aggregate),
86 Expr::ArraySubscript { target, index } => {
87 contains_aggregate(target) || contains_aggregate(index)
88 }
89 Expr::AnyAll { expr, array, .. } => contains_aggregate(expr) || contains_aggregate(array),
90 Expr::Case {
93 operand,
94 branches,
95 else_branch,
96 } => {
97 operand.as_deref().is_some_and(contains_aggregate)
98 || branches
99 .iter()
100 .any(|(w, t)| contains_aggregate(w) || contains_aggregate(t))
101 || else_branch.as_deref().is_some_and(contains_aggregate)
102 }
103 }
104}
105
/// Returns true when `name` (compared ASCII case-insensitively) is a supported aggregate.
///
/// `count_star` is the internal name for `count(*)`; `every` is an alias of `bool_and`.
pub fn is_aggregate_name(name: &str) -> bool {
    const AGGREGATES: [&str; 11] = [
        "count",
        "count_star",
        "sum",
        "min",
        "max",
        "avg",
        "string_agg",
        "array_agg",
        "bool_and",
        "bool_or",
        "every",
    ];
    AGGREGATES.iter().any(|agg| agg.eq_ignore_ascii_case(name))
}
128
/// Running per-group accumulator for one aggregate call.
///
/// A single struct serves every aggregate kind; each kind only touches the fields it
/// needs (see `update_state` / `finalize`).
#[derive(Debug, Default, Clone)]
struct AggState {
    /// Number of contributing rows: every row for `count_star`, non-NULL inputs for
    /// `count`/`sum`/`avg`, and collected items for `string_agg`/`array_agg`.
    count: i64,
    /// Integer part of a `sum`/`avg` total.
    sum_int: i64,
    /// Floating-point part of a `sum`/`avg` total.
    sum_float: f64,
    /// Current `min`/`max` candidate; `None` until a non-NULL input is seen.
    extreme: Option<Value>,
    /// Set once a `Value::Float` input is summed; the result is then reported as a float.
    use_float: bool,
    /// Values collected by `string_agg`/`array_agg`, in input order.
    items: Vec<Value>,
    /// ORDER BY keys for each entry of `items` (ordered aggregates only). `run` sorts
    /// `items` by these keys before finalizing, when the two vectors have equal length.
    item_keys: Vec<Vec<Value>>,
    /// `string_agg` separator: the most recent non-NULL text separator argument seen.
    separator: Option<String>,
    /// `bool_and`/`bool_or` accumulator; `None` until a non-NULL input is seen.
    bool_acc: Option<bool>,
}
159
/// One distinct aggregate call found in the query (select list, ORDER BY or HAVING).
///
/// The index of a spec in the collected list determines its synthetic column name
/// `__agg_{i}`.
#[derive(Debug, Clone)]
struct AggSpec {
    /// Canonical lowercase aggregate name (`every` is stored as `bool_and`).
    name: String,
    /// First argument expression; `None` for the plain `count_star` form.
    arg: Option<Expr>,
    /// Second argument expression; only populated for `string_agg` (the separator).
    arg2: Option<Expr>,
    /// Per-aggregate ORDER BY keys (e.g. `array_agg(x ORDER BY y)`); empty when absent.
    order_by: Vec<spg_sql::ast::OrderBy>,
}
179
/// Output of [`run`]: the projected column schema and the resulting rows.
#[derive(Debug)]
pub struct AggResult {
    /// One schema entry per select-list item (named by alias, else by the expression text).
    pub columns: Vec<ColumnSchema>,
    /// One row per group that survived HAVING, sorted by the statement's ORDER BY if any.
    pub rows: Vec<Row>,
}
187
/// Evaluates an aggregate / GROUP BY `SELECT` over the given input rows.
///
/// Pipeline:
/// 1. Collect every distinct aggregate call from the select list, ORDER BY and HAVING
///    (`collect_aggregates`), then check aggregate argument counts.
/// 2. Bucket `rows` by the encoded GROUP BY key and fold each row into one `AggState`
///    per aggregate. Groups are emitted in first-seen order (`key_order`).
/// 3. Build one synthetic row per group with columns `__grp_{i}` (group values) followed
///    by `__agg_{i}` (finalized aggregate values).
/// 4. Rewrite select items, HAVING and ORDER BY to reference those synthetic columns
///    (`rewrite_expr`) and evaluate them against each synthetic row.
///
/// `schema_cols` / `table_alias` describe `rows` for expression evaluation.
///
/// # Errors
/// Propagates `EvalError`s from expression evaluation and aggregate updates. Returns
/// `EvalError::TypeMismatch` for `SELECT *` combined with aggregates and for aggregate
/// calls with the wrong number of arguments.
#[allow(clippy::too_many_lines)]
pub fn run(
    stmt: &SelectStatement,
    rows: &[&Row],
    schema_cols: &[ColumnSchema],
    table_alias: Option<&str>,
) -> Result<AggResult, EvalError> {
    let ctx = EvalContext::new(schema_cols, table_alias);
    let group_exprs: Vec<Expr> = stmt.group_by.clone().unwrap_or_default();

    // Every aggregate referenced anywhere in the statement gets one slot; the slot index
    // becomes the synthetic column `__agg_{i}`.
    let mut agg_specs: Vec<AggSpec> = Vec::new();
    for item in &stmt.items {
        if let SelectItem::Expr { expr, .. } = item {
            collect_aggregates(expr, &mut agg_specs);
        }
    }
    for o in &stmt.order_by {
        collect_aggregates(&o.expr, &mut agg_specs);
    }
    if let Some(h) = &stmt.having {
        collect_aggregates(h, &mut agg_specs);
    }
    validate_agg_arities(stmt, &agg_specs)?;

    // encoded key -> (group values as first seen, per-aggregate state).
    let mut groups: BTreeMap<String, (Vec<Value>, Vec<AggState>)> = BTreeMap::new();
    // BTreeMap iteration is key-sorted, so insertion order is tracked separately.
    let mut key_order: Vec<String> = Vec::new();
    if rows.is_empty() && group_exprs.is_empty() {
        // Without GROUP BY, empty input still produces one group, so that e.g.
        // `SELECT count(*)` returns a single row with 0.
        let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
        groups.insert(String::new(), (Vec::new(), init));
        key_order.push(String::new());
    }

    for row in rows {
        let group_vals: Vec<Value> = group_exprs
            .iter()
            .map(|g| eval::eval_expr(g, row, &ctx))
            .collect::<Result<_, _>>()?;
        // Case-insensitive collated text is lowercased for keying only; the group keeps
        // the first row's original spelling in `group_vals`.
        let mut key_vals = group_vals.clone();
        for (i, g) in group_exprs.iter().enumerate() {
            if matches!(
                eval::column_collation(g, &ctx),
                Some(spg_storage::Collation::CaseInsensitive)
            ) {
                if let Value::Text(s) = &key_vals[i] {
                    key_vals[i] = Value::Text(s.to_ascii_lowercase());
                }
            }
        }
        let key = encode_key(&key_vals);
        let entry = groups.entry(key.clone()).or_insert_with(|| {
            key_order.push(key.clone());
            let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
            (group_vals.clone(), init)
        });
        for (i, spec) in agg_specs.iter().enumerate() {
            // Argument-less aggregates (count_star) receive a dummy non-NULL value.
            let arg_val = match &spec.arg {
                None => Value::Bool(true),
                Some(e) => eval::eval_expr(e, row, &ctx)?,
            };
            let arg2_val = match &spec.arg2 {
                None => None,
                Some(e) => Some(eval::eval_expr(e, row, &ctx)?),
            };
            // Ordered aggregates capture this row's sort keys next to the collected item.
            let order_keys = if spec.order_by.is_empty() {
                None
            } else {
                let mut keys = Vec::with_capacity(spec.order_by.len());
                for o in &spec.order_by {
                    keys.push(eval::eval_expr(&o.expr, row, &ctx)?);
                }
                Some(keys)
            };
            update_state(
                &mut entry.1[i],
                &spec.name,
                &arg_val,
                arg2_val.as_ref(),
                order_keys,
            )?;
        }
    }

    // Group column types are probed from the first input row; empty input or a NULL
    // probe value falls back to Text.
    let group_types: Vec<DataType> = if rows.is_empty() {
        group_exprs.iter().map(|_| DataType::Text).collect()
    } else {
        let probe = rows[0];
        group_exprs
            .iter()
            .map(|g| {
                eval::eval_expr(g, probe, &ctx).map(|v| v.data_type().unwrap_or(DataType::Text))
            })
            .collect::<Result<_, _>>()?
    };
    let agg_types: Vec<DataType> = agg_specs.iter().map(infer_agg_type).collect();
    let mut synth_schema: Vec<ColumnSchema> = Vec::new();
    for (i, ty) in group_types.iter().enumerate() {
        synth_schema.push(ColumnSchema::new(format!("__grp_{i}"), *ty, true));
    }
    for (i, ty) in agg_types.iter().enumerate() {
        synth_schema.push(ColumnSchema::new(format!("__agg_{i}"), *ty, true));
    }

    let mut synth_rows: Vec<Row> = Vec::new();
    for k in &key_order {
        let (gvals, states) = &groups[k];
        let mut values: Vec<Value> = Vec::with_capacity(synth_schema.len());
        values.extend(gvals.iter().cloned());
        for (i, st) in states.iter().enumerate() {
            // Ordered aggregates are sorted here, at finalize time, using the per-item
            // keys. `sort_by` is stable, so ties keep input order. The sort is skipped if
            // keys and items are out of step.
            let st_sorted;
            let st_final: &AggState =
                if !agg_specs[i].order_by.is_empty() && st.item_keys.len() == st.items.len() {
                    let mut idx: Vec<usize> = (0..st.items.len()).collect();
                    let ob = &agg_specs[i].order_by;
                    idx.sort_by(|&x, &y| {
                        for (k, o) in ob.iter().enumerate() {
                            let cmp = crate::order_by_value_cmp(
                                o.desc,
                                o.nulls_first,
                                &st.item_keys[x][k],
                                &st.item_keys[y][k],
                            );
                            if cmp != core::cmp::Ordering::Equal {
                                return cmp;
                            }
                        }
                        core::cmp::Ordering::Equal
                    });
                    let mut sorted = st.clone();
                    sorted.items = idx.iter().map(|&j| st.items[j].clone()).collect();
                    st_sorted = sorted;
                    &st_sorted
                } else {
                    st
                };
            values.push(finalize(&agg_specs[i].name, st_final));
        }
        synth_rows.push(Row::new(values));
    }

    // Output column types are only known when an item rewrites to a bare synthetic
    // column; anything else is reported as Text (see `agg_or_group_type`).
    let columns: Vec<ColumnSchema> = stmt
        .items
        .iter()
        .map(|item| match item {
            SelectItem::Wildcard => Err(EvalError::TypeMismatch {
                detail: "SELECT * with aggregates is not supported".into(),
            }),
            SelectItem::Expr { expr, alias } => {
                let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
                let name = alias.clone().unwrap_or_else(|| expr.to_string());
                Ok(ColumnSchema::new(
                    name,
                    agg_or_group_type(&rewritten, &synth_schema),
                    true,
                ))
            }
        })
        .collect::<Result<_, _>>()?;

    // From here on, expressions are evaluated against the synthetic per-group rows.
    let synth_ctx = EvalContext::new(&synth_schema, None);
    let having_rewritten = stmt
        .having
        .as_ref()
        .map(|h| rewrite_expr(h, &group_exprs, &agg_specs));
    // `kept_synth[i]` is the synthetic row behind `out_rows[i]`; ORDER BY keys are
    // evaluated against it below.
    let mut kept_synth: Vec<Row> = Vec::new();
    let mut out_rows: Vec<Row> = Vec::new();
    for srow in synth_rows {
        if let Some(h) = &having_rewritten {
            let cond = eval::eval_expr(h, &srow, &synth_ctx)?;
            // Only a definite TRUE keeps the group; FALSE and NULL drop it.
            if !matches!(cond, Value::Bool(true)) {
                continue;
            }
        }
        let mut values: Vec<Value> = Vec::with_capacity(columns.len());
        for item in &stmt.items {
            if let SelectItem::Expr { expr, .. } = item {
                let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
                values.push(eval::eval_expr(&rewritten, &srow, &synth_ctx)?);
            }
        }
        kept_synth.push(srow);
        out_rows.push(Row::new(values));
    }

    if !stmt.order_by.is_empty() {
        // ORDER BY keys are rewritten the same way as select items, so they may reference
        // group expressions and aggregates, but not output aliases.
        let rewritten: Vec<Expr> = stmt
            .order_by
            .iter()
            .map(|o| rewrite_expr(&o.expr, &group_exprs, &agg_specs))
            .collect();
        let keys_meta: Vec<(bool, Option<bool>)> = stmt
            .order_by
            .iter()
            .map(|o| (o.desc, o.nulls_first))
            .collect();
        let mut tagged: Vec<(Vec<Value>, Row)> = kept_synth
            .into_iter()
            .zip(out_rows)
            .map(|(s, o)| {
                let mut keys = Vec::with_capacity(rewritten.len());
                for e in &rewritten {
                    keys.push(eval::eval_expr(e, &s, &synth_ctx)?);
                }
                Ok::<_, EvalError>((keys, o))
            })
            .collect::<Result<_, _>>()?;
        tagged.sort_by(|a, b| {
            use core::cmp::Ordering;
            for (i, (ka, kb)) in a.0.iter().zip(b.0.iter()).enumerate() {
                let (desc, nf) = keys_meta[i];
                let cmp = crate::order_by_value_cmp(desc, nf, ka, kb);
                if cmp != Ordering::Equal {
                    return cmp;
                }
            }
            Ordering::Equal
        });
        out_rows = tagged.into_iter().map(|(_, o)| o).collect();
    }

    Ok(AggResult {
        columns,
        rows: out_rows,
    })
}
456
457fn validate_agg_arities(stmt: &SelectStatement, _specs: &[AggSpec]) -> Result<(), EvalError> {
463 fn walk(e: &Expr) -> Result<(), EvalError> {
464 if let Expr::FunctionCall { name, args } = e {
465 let lower = name.to_ascii_lowercase();
466 let expected: Option<usize> = match lower.as_str() {
467 "count_star" => Some(0),
468 "count" | "sum" | "avg" | "min" | "max" | "array_agg"
469 | "bool_and" | "bool_or" | "every" => Some(1),
473 "string_agg" => Some(2),
474 _ => None,
475 };
476 if let Some(want) = expected
477 && args.len() != want
478 {
479 return Err(EvalError::TypeMismatch {
480 detail: alloc::format!("{lower}() takes {want} arg(s), got {}", args.len()),
481 });
482 }
483 for a in args {
484 walk(a)?;
485 }
486 } else if let Expr::Binary { lhs, rhs, .. } = e {
487 walk(lhs)?;
488 walk(rhs)?;
489 } else if let Expr::Unary { expr, .. }
490 | Expr::Cast { expr, .. }
491 | Expr::IsNull { expr, .. } = e
492 {
493 walk(expr)?;
494 }
495 Ok(())
496 }
497 for item in &stmt.items {
498 if let SelectItem::Expr { expr, .. } = item {
499 walk(expr)?;
500 }
501 }
502 for o in &stmt.order_by {
503 walk(&o.expr)?;
504 }
505 if let Some(h) = &stmt.having {
506 walk(h)?;
507 }
508 Ok(())
509}
510
511fn collect_aggregates(e: &Expr, out: &mut Vec<AggSpec>) {
512 match e {
513 Expr::AggregateOrdered { call, order_by } => {
516 if let Expr::FunctionCall { name, args } = call.as_ref() {
517 let lower = name.to_ascii_lowercase();
518 if is_aggregate_name(&lower) {
519 let canonical = if lower == "every" {
520 "bool_and".to_string()
521 } else {
522 lower
523 };
524 let spec = AggSpec {
525 name: canonical,
526 arg: args.first().cloned(),
527 arg2: if name.eq_ignore_ascii_case("string_agg") {
528 args.get(1).cloned()
529 } else {
530 None
531 },
532 order_by: order_by.clone(),
533 };
534 if !out.iter().any(|s| {
535 s.name == spec.name
536 && s.arg == spec.arg
537 && s.arg2 == spec.arg2
538 && s.order_by == spec.order_by
539 }) {
540 out.push(spec);
541 }
542 return;
543 }
544 }
545 collect_aggregates(call, out);
546 for o in order_by {
547 collect_aggregates(&o.expr, out);
548 }
549 }
550 Expr::FunctionCall { name, args } => {
551 let lower = name.to_ascii_lowercase();
552 if is_aggregate_name(&lower) {
553 let arg = if lower == "count_star" {
554 None
555 } else {
556 args.first().cloned()
557 };
558 let arg2 = if lower == "string_agg" {
562 args.get(1).cloned()
563 } else {
564 None
565 };
566 let canonical = if lower == "every" {
570 "bool_and".to_string()
571 } else {
572 lower
573 };
574 let spec = AggSpec {
575 name: canonical,
576 arg: arg.clone(),
577 arg2: arg2.clone(),
578 order_by: Vec::new(),
579 };
580 if !out.iter().any(|s| {
581 s.name == spec.name
582 && s.arg == spec.arg
583 && s.arg2 == spec.arg2
584 && s.order_by == spec.order_by
585 }) {
586 out.push(spec);
587 }
588 } else {
591 for a in args {
592 collect_aggregates(a, out);
593 }
594 }
595 }
596 Expr::Binary { lhs, rhs, .. } => {
597 collect_aggregates(lhs, out);
598 collect_aggregates(rhs, out);
599 }
600 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
601 collect_aggregates(expr, out);
602 }
603 Expr::Like { expr, pattern, .. } => {
604 collect_aggregates(expr, out);
605 collect_aggregates(pattern, out);
606 }
607 Expr::Extract { source, .. } => collect_aggregates(source, out),
608 Expr::ScalarSubquery(_)
611 | Expr::Exists { .. }
612 | Expr::InSubquery { .. }
613 | Expr::WindowFunction { .. }
614 | Expr::Literal(_)
615 | Expr::Placeholder(_)
616 | Expr::Column(_) => {}
617 Expr::Array(items) => {
620 for elem in items {
621 collect_aggregates(elem, out);
622 }
623 }
624 Expr::ArraySubscript { target, index } => {
625 collect_aggregates(target, out);
626 collect_aggregates(index, out);
627 }
628 Expr::AnyAll { expr, array, .. } => {
629 collect_aggregates(expr, out);
630 collect_aggregates(array, out);
631 }
632 Expr::Case {
633 operand,
634 branches,
635 else_branch,
636 } => {
637 if let Some(o) = operand {
638 collect_aggregates(o, out);
639 }
640 for (w, t) in branches {
641 collect_aggregates(w, out);
642 collect_aggregates(t, out);
643 }
644 if let Some(e) = else_branch {
645 collect_aggregates(e, out);
646 }
647 }
648 }
649}
650
651fn update_state(
652 st: &mut AggState,
653 name: &str,
654 v: &Value,
655 arg2: Option<&Value>,
656 order_keys: Option<Vec<Value>>,
657) -> Result<(), EvalError> {
658 let is_null = matches!(v, Value::Null);
659 match name {
660 "count_star" => st.count += 1,
661 "count" => {
662 if !is_null {
663 st.count += 1;
664 }
665 }
666 "sum" | "avg" => {
667 if is_null {
668 return Ok(());
669 }
670 st.count += 1;
671 match v {
672 Value::Int(n) => st.sum_int += i64::from(*n),
673 Value::BigInt(n) => st.sum_int += *n,
674 Value::Float(x) => {
675 st.use_float = true;
676 st.sum_float += *x;
677 }
678 other => {
679 return Err(EvalError::TypeMismatch {
680 detail: format!("sum/avg need numeric, got {:?}", other.data_type()),
681 });
682 }
683 }
684 }
685 "min" => {
686 if is_null {
687 return Ok(());
688 }
689 match &st.extreme {
690 None => st.extreme = Some(v.clone()),
691 Some(cur) => {
692 if value_cmp(v, cur) == core::cmp::Ordering::Less {
693 st.extreme = Some(v.clone());
694 }
695 }
696 }
697 }
698 "max" => {
699 if is_null {
700 return Ok(());
701 }
702 match &st.extreme {
703 None => st.extreme = Some(v.clone()),
704 Some(cur) => {
705 if value_cmp(v, cur) == core::cmp::Ordering::Greater {
706 st.extreme = Some(v.clone());
707 }
708 }
709 }
710 }
711 "string_agg" => {
719 if let Some(sep) = arg2
720 && let Value::Text(s) = sep
721 {
722 st.separator = Some(s.clone());
723 }
724 if is_null {
725 return Ok(());
726 }
727 if let Value::Text(s) = v {
728 st.items.push(Value::Text(s.clone()));
729 if let Some(k) = order_keys {
730 st.item_keys.push(k);
731 }
732 st.count += 1;
733 } else {
734 return Err(EvalError::TypeMismatch {
735 detail: format!("string_agg requires text value, got {:?}", v.data_type()),
736 });
737 }
738 }
739 "array_agg" => {
745 st.items.push(v.clone());
746 if let Some(k) = order_keys {
747 st.item_keys.push(k);
748 }
749 st.count += 1;
750 }
751 "bool_and" => {
755 if is_null {
756 return Ok(());
757 }
758 let b = match v {
759 Value::Bool(b) => *b,
760 other => {
761 return Err(EvalError::TypeMismatch {
762 detail: format!("bool_and requires bool, got {:?}", other.data_type()),
763 });
764 }
765 };
766 st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc && b));
767 }
768 "bool_or" => {
771 if is_null {
772 return Ok(());
773 }
774 let b = match v {
775 Value::Bool(b) => *b,
776 other => {
777 return Err(EvalError::TypeMismatch {
778 detail: format!("bool_or requires bool, got {:?}", other.data_type()),
779 });
780 }
781 };
782 st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc || b));
783 }
784 _ => unreachable!("non-aggregate {name} in update_state"),
785 }
786 Ok(())
787}
788
789#[allow(clippy::cast_precision_loss)]
790fn finalize(name: &str, st: &AggState) -> Value {
791 match name {
792 "count" | "count_star" => Value::BigInt(st.count),
793 "sum" => {
794 if st.count == 0 {
795 Value::Null
796 } else if st.use_float {
797 Value::Float(st.sum_float + (st.sum_int as f64))
798 } else {
799 Value::BigInt(st.sum_int)
800 }
801 }
802 "avg" => {
803 if st.count == 0 {
804 Value::Null
805 } else {
806 let total = if st.use_float {
807 st.sum_float + (st.sum_int as f64)
808 } else {
809 st.sum_int as f64
810 };
811 Value::Float(total / (st.count as f64))
812 }
813 }
814 "min" | "max" => st.extreme.clone().unwrap_or(Value::Null),
815 "string_agg" => {
819 if st.items.is_empty() {
820 return Value::Null;
821 }
822 let sep = st.separator.clone().unwrap_or_default();
823 let mut out = String::new();
824 for (i, item) in st.items.iter().enumerate() {
825 if i > 0 {
826 out.push_str(&sep);
827 }
828 if let Value::Text(s) = item {
829 out.push_str(s);
830 }
831 }
832 Value::Text(out)
833 }
834 "array_agg" => {
841 if st.items.is_empty() {
842 return Value::Null;
843 }
844 let probe = st.items.iter().find(|v| !v.is_null());
845 match probe.and_then(spg_storage::Value::data_type) {
846 Some(DataType::Int) | Some(DataType::SmallInt) => {
847 let items: Vec<Option<i32>> = st
848 .items
849 .iter()
850 .map(|v| match v {
851 Value::Int(n) => Some(*n),
852 Value::SmallInt(n) => Some(i32::from(*n)),
853 _ => None,
854 })
855 .collect();
856 Value::IntArray(items)
857 }
858 Some(DataType::BigInt) => {
859 let items: Vec<Option<i64>> = st
860 .items
861 .iter()
862 .map(|v| match v {
863 Value::BigInt(n) => Some(*n),
864 _ => None,
865 })
866 .collect();
867 Value::BigIntArray(items)
868 }
869 _ => {
870 let items: Vec<Option<String>> = st
871 .items
872 .iter()
873 .map(|v| match v {
874 Value::Text(s) => Some(s.clone()),
875 Value::Null => None,
876 other => Some(format!("{other:?}")),
877 })
878 .collect();
879 Value::TextArray(items)
880 }
881 }
882 }
883 "bool_and" | "bool_or" => st.bool_acc.map_or(Value::Null, Value::Bool),
887 _ => unreachable!(),
888 }
889}
890
891fn infer_agg_type(spec: &AggSpec) -> DataType {
892 match spec.name.as_str() {
893 "count" | "count_star" | "sum" => DataType::BigInt,
897 "avg" => DataType::Float,
898 "string_agg" => DataType::Text,
900 "array_agg" => DataType::TextArray,
907 "bool_and" | "bool_or" => DataType::Bool,
910 _ => DataType::Text,
913 }
914}
915
916fn agg_or_group_type(e: &Expr, synth: &[ColumnSchema]) -> DataType {
917 if let Expr::Column(c) = e
918 && let Some(s) = synth.iter().find(|s| s.name == c.name)
919 {
920 return s.ty;
921 }
922 DataType::Text
925}
926
927fn rewrite_expr(e: &Expr, group_exprs: &[Expr], aggs: &[AggSpec]) -> Expr {
928 if let Expr::AggregateOrdered { call, order_by } = e
931 && let Expr::FunctionCall { name, args } = call.as_ref()
932 {
933 let lower = name.to_ascii_lowercase();
934 if is_aggregate_name(&lower) {
935 let canonical: &str = if lower == "every" { "bool_and" } else { &lower };
936 let arg = args.first().cloned();
937 let arg2 = if lower == "string_agg" {
938 args.get(1).cloned()
939 } else {
940 None
941 };
942 for (i, spec) in aggs.iter().enumerate() {
943 if spec.name == canonical
944 && spec.arg == arg
945 && spec.arg2 == arg2
946 && spec.order_by == *order_by
947 {
948 return Expr::Column(spg_sql::ast::ColumnName {
949 qualifier: None,
950 name: format!("__agg_{i}"),
951 });
952 }
953 }
954 }
955 }
956 if let Expr::FunctionCall { name, args } = e {
958 let lower = name.to_ascii_lowercase();
959 if is_aggregate_name(&lower) {
960 let arg = if lower == "count_star" {
961 None
962 } else {
963 args.first().cloned()
964 };
965 let arg2 = if lower == "string_agg" {
968 args.get(1).cloned()
969 } else {
970 None
971 };
972 let canonical: &str = if lower == "every" {
976 "bool_and"
977 } else {
978 lower.as_str()
979 };
980 for (i, spec) in aggs.iter().enumerate() {
981 if spec.name == canonical
982 && spec.arg == arg
983 && spec.arg2 == arg2
984 && spec.order_by.is_empty()
985 {
986 return Expr::Column(spg_sql::ast::ColumnName {
987 qualifier: None,
988 name: format!("__agg_{i}"),
989 });
990 }
991 }
992 }
993 }
994 for (i, g) in group_exprs.iter().enumerate() {
996 if g == e {
997 return Expr::Column(spg_sql::ast::ColumnName {
998 qualifier: None,
999 name: format!("__grp_{i}"),
1000 });
1001 }
1002 }
1003 match e {
1005 Expr::AggregateOrdered { call, order_by } => Expr::AggregateOrdered {
1006 call: Box::new(rewrite_expr(call, group_exprs, aggs)),
1007 order_by: order_by
1008 .iter()
1009 .map(|o| spg_sql::ast::OrderBy {
1010 expr: rewrite_expr(&o.expr, group_exprs, aggs),
1011 desc: o.desc,
1012 nulls_first: o.nulls_first,
1013 })
1014 .collect(),
1015 },
1016 Expr::Binary { lhs, op, rhs } => Expr::Binary {
1017 lhs: Box::new(rewrite_expr(lhs, group_exprs, aggs)),
1018 op: *op,
1019 rhs: Box::new(rewrite_expr(rhs, group_exprs, aggs)),
1020 },
1021 Expr::Unary { op, expr } => Expr::Unary {
1022 op: *op,
1023 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1024 },
1025 Expr::Cast { expr, target } => Expr::Cast {
1026 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1027 target: *target,
1028 },
1029 Expr::IsNull { expr, negated } => Expr::IsNull {
1030 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1031 negated: *negated,
1032 },
1033 Expr::FunctionCall { name, args } => Expr::FunctionCall {
1034 name: name.clone(),
1035 args: args
1036 .iter()
1037 .map(|a| rewrite_expr(a, group_exprs, aggs))
1038 .collect(),
1039 },
1040 Expr::Like {
1041 expr,
1042 pattern,
1043 negated,
1044 } => Expr::Like {
1045 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1046 pattern: Box::new(rewrite_expr(pattern, group_exprs, aggs)),
1047 negated: *negated,
1048 },
1049 Expr::Extract { field, source } => Expr::Extract {
1050 field: *field,
1051 source: Box::new(rewrite_expr(source, group_exprs, aggs)),
1052 },
1053 Expr::ScalarSubquery(_)
1056 | Expr::Exists { .. }
1057 | Expr::InSubquery { .. }
1058 | Expr::WindowFunction { .. }
1059 | Expr::Literal(_)
1060 | Expr::Placeholder(_)
1061 | Expr::Column(_) => e.clone(),
1062 Expr::Array(items) => Expr::Array(
1064 items
1065 .iter()
1066 .map(|elem| rewrite_expr(elem, group_exprs, aggs))
1067 .collect(),
1068 ),
1069 Expr::ArraySubscript { target, index } => Expr::ArraySubscript {
1070 target: Box::new(rewrite_expr(target, group_exprs, aggs)),
1071 index: Box::new(rewrite_expr(index, group_exprs, aggs)),
1072 },
1073 Expr::AnyAll {
1074 expr,
1075 op,
1076 array,
1077 is_any,
1078 } => Expr::AnyAll {
1079 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1080 op: *op,
1081 array: Box::new(rewrite_expr(array, group_exprs, aggs)),
1082 is_any: *is_any,
1083 },
1084 Expr::Case {
1085 operand,
1086 branches,
1087 else_branch,
1088 } => Expr::Case {
1089 operand: operand
1090 .as_deref()
1091 .map(|o| Box::new(rewrite_expr(o, group_exprs, aggs))),
1092 branches: branches
1093 .iter()
1094 .map(|(w, t)| {
1095 (
1096 rewrite_expr(w, group_exprs, aggs),
1097 rewrite_expr(t, group_exprs, aggs),
1098 )
1099 })
1100 .collect(),
1101 else_branch: else_branch
1102 .as_deref()
1103 .map(|e| Box::new(rewrite_expr(e, group_exprs, aggs))),
1104 },
1105 }
1106}
1107
1108fn encode_key(vals: &[Value]) -> String {
1110 let mut out = String::new();
1111 for v in vals {
1112 match v {
1113 Value::Null => out.push_str("N|"),
1114 Value::SmallInt(n) => {
1115 out.push('s');
1116 out.push_str(&n.to_string());
1117 out.push('|');
1118 }
1119 Value::Int(n) => {
1120 out.push('I');
1121 out.push_str(&n.to_string());
1122 out.push('|');
1123 }
1124 Value::BigInt(n) => {
1125 out.push('B');
1126 out.push_str(&n.to_string());
1127 out.push('|');
1128 }
1129 Value::Float(x) => {
1130 out.push('F');
1131 out.push_str(&x.to_string());
1132 out.push('|');
1133 }
1134 Value::Bool(b) => {
1135 out.push(if *b { 'T' } else { 'f' });
1136 out.push('|');
1137 }
1138 Value::Text(s) => {
1139 out.push('S');
1140 out.push_str(s);
1141 out.push('|');
1142 }
1143 Value::Vector(v) => {
1144 out.push('V');
1145 for x in v {
1146 out.push_str(&x.to_string());
1147 out.push(',');
1148 }
1149 out.push('|');
1150 }
1151 Value::Sq8Vector(q) => {
1157 out.push('Q');
1158 out.push_str(&q.min.to_string());
1159 out.push('@');
1160 out.push_str(&q.max.to_string());
1161 out.push(':');
1162 for b in &q.bytes {
1163 out.push_str(&b.to_string());
1164 out.push(',');
1165 }
1166 out.push('|');
1167 }
1168 Value::HalfVector(h) => {
1172 out.push('H');
1173 for b in &h.bytes {
1174 out.push_str(&b.to_string());
1175 out.push(',');
1176 }
1177 out.push('|');
1178 }
1179 Value::Numeric { scaled, scale } => {
1180 out.push('D');
1181 out.push_str(&scaled.to_string());
1182 out.push('@');
1183 out.push_str(&scale.to_string());
1184 out.push('|');
1185 }
1186 Value::Date(d) => {
1187 out.push('d');
1188 out.push_str(&d.to_string());
1189 out.push('|');
1190 }
1191 Value::Timestamp(t) => {
1192 out.push('t');
1193 out.push_str(&t.to_string());
1194 out.push('|');
1195 }
1196 Value::Interval { months, micros } => {
1197 out.push('i');
1198 out.push_str(&months.to_string());
1199 out.push('m');
1200 out.push_str(µs.to_string());
1201 out.push('|');
1202 }
1203 Value::Json(s) => {
1204 out.push('j');
1205 out.push_str(s);
1206 out.push('|');
1207 }
1208 _ => {
1213 out.push('?');
1214 out.push_str(&format!("{v:?}"));
1215 out.push('|');
1216 }
1217 }
1218 }
1219 out
1220}
1221
1222#[allow(clippy::cast_precision_loss)]
1223fn value_cmp(a: &Value, b: &Value) -> core::cmp::Ordering {
1224 use core::cmp::Ordering::Equal;
1225 match (a, b) {
1226 (Value::Null, Value::Null) => Equal,
1227 (Value::Null, _) => core::cmp::Ordering::Greater, (_, Value::Null) => core::cmp::Ordering::Less,
1229 (Value::Int(x), Value::Int(y)) => x.cmp(y),
1230 (Value::BigInt(x), Value::BigInt(y)) => x.cmp(y),
1231 (Value::Int(x), Value::BigInt(y)) => i64::from(*x).cmp(y),
1232 (Value::BigInt(x), Value::Int(y)) => x.cmp(&i64::from(*y)),
1233 (Value::Float(x), Value::Float(y)) => x.partial_cmp(y).unwrap_or(Equal),
1234 (Value::Int(x), Value::Float(y)) => f64::from(*x).partial_cmp(y).unwrap_or(Equal),
1235 (Value::Float(x), Value::Int(y)) => x.partial_cmp(&f64::from(*y)).unwrap_or(Equal),
1236 (Value::BigInt(x), Value::Float(y)) => (*x as f64).partial_cmp(y).unwrap_or(Equal),
1237 (Value::Float(x), Value::BigInt(y)) => x.partial_cmp(&(*y as f64)).unwrap_or(Equal),
1238 (Value::Text(x), Value::Text(y)) => x.cmp(y),
1239 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
1240 _ => Equal,
1241 }
1242}