Skip to main content

spg_engine/
aggregate.rs

1//! Aggregate executor.
2//!
3//! Handles `SELECT … <aggs> … [GROUP BY …]` queries. The planning strategy
4//! is straightforward:
5//!
//! 1. Walk the SELECT, ORDER BY and HAVING expressions to find every
//!    aggregate function call. Dedupe by AST equality and assign each
//!    `__agg_<i>`.
8//! 2. Same for every `GROUP BY` expression: assign `__grp_<j>`.
9//! 3. Stream the WHERE-filtered rows, group by the tuple of GROUP BY
10//!    values, and update per-group aggregate state.
11//! 4. Materialise a synthetic per-group row containing
12//!    `[__grp_0..__grp_K, __agg_0..__agg_N]` and rewrite the user's
13//!    SELECT / ORDER BY expressions to reference those synthetic columns
14//!    instead of the originals.
15//! 5. Evaluate the rewritten expressions against the synthetic schema and
16//!    emit results.
17//!
//! v1.8 implements `count(*)`, `count(expr)`, `sum`, `min`, `max`, `avg`;
//! v7.17.0 adds `string_agg`, `array_agg`, `bool_and`, `bool_or` and
//! `every`. NULL semantics follow PG: aggregates skip NULL inputs (except
//! `count(*)`, which counts rows, and `array_agg`, which keeps NULL
//! elements). `sum(int)` widens to `BigInt`; `avg(int|bigint)` returns
//! `Float`.
22
23use alloc::boxed::Box;
24use alloc::collections::BTreeMap;
25use alloc::format;
26use alloc::string::{String, ToString};
27use alloc::vec::Vec;
28
29use spg_sql::ast::{Expr, SelectItem, SelectStatement};
30use spg_storage::{ColumnSchema, DataType, Row, Value};
31
32use crate::eval::{self, EvalContext, EvalError};
33
34/// True if this statement should go through the aggregate path.
35pub fn uses_aggregate(stmt: &SelectStatement) -> bool {
36    if stmt.group_by.is_some() || stmt.having.is_some() {
37        return true;
38    }
39    for item in &stmt.items {
40        if let SelectItem::Expr { expr, .. } = item
41            && contains_aggregate(expr)
42        {
43            return true;
44        }
45    }
46    for o in &stmt.order_by {
47        if contains_aggregate(&o.expr) {
48            return true;
49        }
50    }
51    if let Some(h) = &stmt.having
52        && contains_aggregate(h)
53    {
54        return true;
55    }
56    false
57}
58
59pub fn contains_aggregate(e: &Expr) -> bool {
60    match e {
61        Expr::FunctionCall { name, args } => {
62            is_aggregate_name(name) || args.iter().any(contains_aggregate)
63        }
64        Expr::Binary { lhs, rhs, .. } => contains_aggregate(lhs) || contains_aggregate(rhs),
65        Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
66            contains_aggregate(expr)
67        }
68        Expr::Like { expr, pattern, .. } => contains_aggregate(expr) || contains_aggregate(pattern),
69        Expr::Extract { source, .. } => contains_aggregate(source),
70        // v4.10 subqueries + v4.12 window functions / Literal /
71        // Column — all non-aggregate leaves from the regular
72        // aggregate planner's POV. Window-bearing projections are
73        // routed to exec_select_with_window before this runs.
74        Expr::ScalarSubquery(_)
75        | Expr::Exists { .. }
76        | Expr::InSubquery { .. }
77        | Expr::WindowFunction { .. }
78        | Expr::Literal(_)
79        | Expr::Placeholder(_)
80        | Expr::Column(_) => false,
81        // v7.10.10 — recurse into array constructor / subscript /
82        // ANY/ALL children. Aggregates inside `ARRAY[SUM(x)]` are
83        // valid PG and must be detected here.
84        Expr::Array(items) => items.iter().any(contains_aggregate),
85        Expr::ArraySubscript { target, index } => {
86            contains_aggregate(target) || contains_aggregate(index)
87        }
88        Expr::AnyAll { expr, array, .. } => contains_aggregate(expr) || contains_aggregate(array),
89        // v7.13.0 — CASE WHEN … END. Recurse into operand,
90        // every (WHEN, THEN) pair, and the ELSE branch.
91        Expr::Case {
92            operand,
93            branches,
94            else_branch,
95        } => {
96            operand.as_deref().is_some_and(contains_aggregate)
97                || branches
98                    .iter()
99                    .any(|(w, t)| contains_aggregate(w) || contains_aggregate(t))
100                || else_branch.as_deref().is_some_and(contains_aggregate)
101        }
102    }
103}
104
/// Returns `true` when `name` is one of the aggregate function names this
/// executor understands, compared ASCII-case-insensitively.
///
/// The comparison is done in place, so no lowercase copy of `name` is
/// allocated. This function is called for every function-call node by
/// several AST walkers in this file.
pub fn is_aggregate_name(name: &str) -> bool {
    // Every entry must be lowercase ASCII.
    const AGGREGATE_NAMES: [&str; 11] = [
        "count",
        "count_star",
        "sum",
        "min",
        "max",
        "avg",
        // v7.17.0 — variadic / collection aggregates. ORM reports
        // (Hibernate / Rails / Django) emit these in GROUP BY rollups;
        // pre-7.17 SPG hit "unknown aggregate".
        "string_agg",
        "array_agg",
        // v7.17.0 — boolean aggregates. `every` is the SQL-standard
        // alias for `bool_and`.
        "bool_and",
        "bool_or",
        "every",
    ];
    AGGREGATE_NAMES
        .iter()
        .any(|known| name.eq_ignore_ascii_case(known))
}
127
/// Per-aggregate running state.
///
/// One value of this struct is kept per aggregate per group. All
/// aggregate kinds share the same struct; each kind reads and writes only
/// the fields it needs and leaves the others at their `Default` value.
#[derive(Debug, Default, Clone)]
struct AggState {
    /// Number of accumulated inputs. `count_star` and `array_agg` bump it
    /// for every row; `count`, `sum`, `avg` and `string_agg` bump it only
    /// for non-NULL inputs. `sum` / `avg` finalize to NULL when it is 0.
    count: i64,
    /// Running total of the `Int` / `BigInt` inputs seen by `sum` / `avg`.
    sum_int: i64,
    /// Running total of the `Float` inputs seen by `sum` / `avg`.
    sum_float: f64,
    /// Current smallest (`min`) or largest (`max`) non-NULL input;
    /// `None` until the first non-NULL input arrives.
    extreme: Option<Value>,
    /// Set once any `Float` input has been seen by `sum` / `avg`; makes
    /// `sum` finalize to `Float` instead of `BigInt`.
    use_float: bool,
    /// v7.17.0 — running collection for string_agg / array_agg.
    /// Each entry is one row's contribution. string_agg only pushes
    /// non-NULL text values; array_agg pushes every input, keeping
    /// NULL as `Value::Null`. Pushing in insertion order matches
    /// PG behaviour when no `ORDER BY` is given inside the
    /// aggregate call.
    items: Vec<Value>,
    /// v7.17.0 — captured separator for string_agg. PG accepts a
    /// non-constant separator expression but in practice every
    /// caller passes a literal; the engine snapshots the last
    /// non-NULL text it sees, which matches PG's "use the latest
    /// row's value" behaviour.
    separator: Option<String>,
    /// v7.17.0 — running boolean accumulator for bool_and /
    /// bool_or / every. `None` until the first non-NULL input;
    /// at finalize None → SQL NULL.
    bool_acc: Option<bool>,
}
154
/// One distinct aggregate call found in the statement. Its position in
/// the collected list is the `<i>` of its synthetic `__agg_<i>` column.
#[derive(Debug, Clone)]
struct AggSpec {
    name: String, // lowercased and canonicalised (`every` is stored as `bool_and`)
    /// First argument (value expression) for every aggregate
    /// except `count(*)`. `None` for `count_star`.
    arg: Option<Expr>,
    /// v7.17.0 — second argument. Only `string_agg(value, sep)`
    /// uses it today. `None` for every other aggregate (or for
    /// `array_agg`, which is single-arg). Carried in the spec so
    /// per-row evaluation can re-use the same separator
    /// expression across calls.
    arg2: Option<Expr>,
}
168
/// Output of running the aggregate path. `columns` describes the
/// projected SELECT items and `rows` holds one row per group that passed
/// HAVING. When the statement has an ORDER BY, `run` has already sorted
/// `rows` by it; per the note in `run`, LIMIT is left to the caller.
#[derive(Debug)]
pub struct AggResult {
    pub columns: Vec<ColumnSchema>,
    pub rows: Vec<Row>,
}
176
177/// Execute aggregate logic against an already-WHERE-filtered iterator of
178/// rows. `table_alias` is the alias accepted by column resolution.
179#[allow(clippy::too_many_lines)]
180pub fn run(
181    stmt: &SelectStatement,
182    rows: &[&Row],
183    schema_cols: &[ColumnSchema],
184    table_alias: Option<&str>,
185) -> Result<AggResult, EvalError> {
186    let ctx = EvalContext::new(schema_cols, table_alias);
187    let group_exprs: Vec<Expr> = stmt.group_by.clone().unwrap_or_default();
188
189    // Collect aggregate sub-expressions across items + order_by.
190    let mut agg_specs: Vec<AggSpec> = Vec::new();
191    for item in &stmt.items {
192        if let SelectItem::Expr { expr, .. } = item {
193            collect_aggregates(expr, &mut agg_specs);
194        }
195    }
196    for o in &stmt.order_by {
197        collect_aggregates(&o.expr, &mut agg_specs);
198    }
199    if let Some(h) = &stmt.having {
200        collect_aggregates(h, &mut agg_specs);
201    }
202    // v7.17.0 — arity validation. The collector tolerates an
203    // arbitrary positional-arg count; here we enforce the
204    // per-aggregate contract so a malformed call (e.g.
205    // `array_agg()` or `string_agg(x)`) surfaces as a SQL error
206    // rather than silently coercing to a degenerate aggregate.
207    validate_agg_arities(stmt, &agg_specs)?;
208
209    // Map group key (vec of values, encoded as canonical string) -> group state.
210    // Order of insertion is preserved via a parallel Vec of keys.
211    let mut groups: BTreeMap<String, (Vec<Value>, Vec<AggState>)> = BTreeMap::new();
212    let mut key_order: Vec<String> = Vec::new();
213    // When there are no GROUP BY exprs *and* there is at least one aggregate,
214    // every row collapses into a single anonymous group keyed by "".
215    if rows.is_empty() && group_exprs.is_empty() {
216        // Single empty-aggregate group: count=0, sum=0, max=NULL, etc.
217        let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
218        groups.insert(String::new(), (Vec::new(), init));
219        key_order.push(String::new());
220    }
221
222    for row in rows {
223        let group_vals: Vec<Value> = group_exprs
224            .iter()
225            .map(|g| eval::eval_expr(g, row, &ctx))
226            .collect::<Result<_, _>>()?;
227        // v7.17.0 Phase 2.5b — case-insensitive group keying.
228        // For each group_expr that's a column reference on a
229        // CaseInsensitive text column, fold the corresponding
230        // value before encoding the key. Display value
231        // (`group_vals`) stays original — only the key folds.
232        let mut key_vals = group_vals.clone();
233        for (i, g) in group_exprs.iter().enumerate() {
234            if matches!(
235                eval::column_collation(g, &ctx),
236                Some(spg_storage::Collation::CaseInsensitive)
237            ) {
238                if let Value::Text(s) = &key_vals[i] {
239                    key_vals[i] = Value::Text(s.to_ascii_lowercase());
240                }
241            }
242        }
243        let key = encode_key(&key_vals);
244        let entry = groups.entry(key.clone()).or_insert_with(|| {
245            key_order.push(key.clone());
246            let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
247            (group_vals.clone(), init)
248        });
249        for (i, spec) in agg_specs.iter().enumerate() {
250            let arg_val = match &spec.arg {
251                None => Value::Bool(true), // count_star: sentinel non-null
252                Some(e) => eval::eval_expr(e, row, &ctx)?,
253            };
254            // v7.17.0 — `string_agg(value, separator)` evaluates the
255            // separator per row but PG treats it as constant; we
256            // pass the per-row value into update_state so a future
257            // varying-separator caller still sees correct output,
258            // even though SPG (like PG) only uses the most recent.
259            let arg2_val = match &spec.arg2 {
260                None => None,
261                Some(e) => Some(eval::eval_expr(e, row, &ctx)?),
262            };
263            update_state(&mut entry.1[i], &spec.name, &arg_val, arg2_val.as_ref())?;
264        }
265    }
266
267    // Build synthetic schema: __grp_0..K then __agg_0..N.
268    let group_types: Vec<DataType> = if rows.is_empty() {
269        // Use Text as a safe stand-in — empty result means schema isn't
270        // observable. Avoids needing to evaluate group exprs on no row.
271        group_exprs.iter().map(|_| DataType::Text).collect()
272    } else {
273        let probe = rows[0];
274        group_exprs
275            .iter()
276            .map(|g| {
277                eval::eval_expr(g, probe, &ctx).map(|v| v.data_type().unwrap_or(DataType::Text))
278            })
279            .collect::<Result<_, _>>()?
280    };
281    let agg_types: Vec<DataType> = agg_specs.iter().map(infer_agg_type).collect();
282    let mut synth_schema: Vec<ColumnSchema> = Vec::new();
283    for (i, ty) in group_types.iter().enumerate() {
284        synth_schema.push(ColumnSchema::new(format!("__grp_{i}"), *ty, true));
285    }
286    for (i, ty) in agg_types.iter().enumerate() {
287        synth_schema.push(ColumnSchema::new(format!("__agg_{i}"), *ty, true));
288    }
289
290    // Materialise synthetic rows.
291    let mut synth_rows: Vec<Row> = Vec::new();
292    for k in &key_order {
293        let (gvals, states) = &groups[k];
294        let mut values: Vec<Value> = Vec::with_capacity(synth_schema.len());
295        values.extend(gvals.iter().cloned());
296        for (i, st) in states.iter().enumerate() {
297            values.push(finalize(&agg_specs[i].name, st));
298        }
299        synth_rows.push(Row::new(values));
300    }
301
302    // Rewrite the user's SELECT items + ORDER BY to reference synthetic
303    // columns. After rewriting, every remaining `Expr::Column` must
304    // resolve against the synthetic schema (i.e. must have been a GROUP
305    // BY expression).
306    let columns: Vec<ColumnSchema> = stmt
307        .items
308        .iter()
309        .map(|item| match item {
310            SelectItem::Wildcard => Err(EvalError::TypeMismatch {
311                detail: "SELECT * with aggregates is not supported".into(),
312            }),
313            SelectItem::Expr { expr, alias } => {
314                let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
315                let name = alias.clone().unwrap_or_else(|| expr.to_string());
316                Ok(ColumnSchema::new(
317                    name,
318                    agg_or_group_type(&rewritten, &synth_schema),
319                    true,
320                ))
321            }
322        })
323        .collect::<Result<_, _>>()?;
324
325    // Project per synthetic row. HAVING filters out groups *before*
326    // we keep the projected row — same semantics as PG: HAVING runs
327    // against the aggregated row (so `HAVING count(*) > 1` works) and
328    // sees only group-by'd columns plus aggregate values.
329    let synth_ctx = EvalContext::new(&synth_schema, None);
330    let having_rewritten = stmt
331        .having
332        .as_ref()
333        .map(|h| rewrite_expr(h, &group_exprs, &agg_specs));
334    let mut kept_synth: Vec<Row> = Vec::new();
335    let mut out_rows: Vec<Row> = Vec::new();
336    for srow in synth_rows {
337        if let Some(h) = &having_rewritten {
338            let cond = eval::eval_expr(h, &srow, &synth_ctx)?;
339            if !matches!(cond, Value::Bool(true)) {
340                continue;
341            }
342        }
343        let mut values: Vec<Value> = Vec::with_capacity(columns.len());
344        for item in &stmt.items {
345            if let SelectItem::Expr { expr, .. } = item {
346                let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
347                values.push(eval::eval_expr(&rewritten, &srow, &synth_ctx)?);
348            }
349        }
350        kept_synth.push(srow);
351        out_rows.push(Row::new(values));
352    }
353
354    // ORDER BY: evaluate the rewritten order_by against each synth row,
355    // sort, then drop the keys. Limit is applied by the caller.
356    if !stmt.order_by.is_empty() {
357        // v6.4.0 — multi-key ORDER BY on aggregate output. Each key
358        // gets its own rewrite + per-key DESC flag.
359        let rewritten: Vec<Expr> = stmt
360            .order_by
361            .iter()
362            .map(|o| rewrite_expr(&o.expr, &group_exprs, &agg_specs))
363            .collect();
364        let descs: Vec<bool> = stmt.order_by.iter().map(|o| o.desc).collect();
365        let mut tagged: Vec<(Vec<Value>, Row)> = kept_synth
366            .into_iter()
367            .zip(out_rows)
368            .map(|(s, o)| {
369                let mut keys = Vec::with_capacity(rewritten.len());
370                for e in &rewritten {
371                    keys.push(eval::eval_expr(e, &s, &synth_ctx)?);
372                }
373                Ok::<_, EvalError>((keys, o))
374            })
375            .collect::<Result<_, _>>()?;
376        tagged.sort_by(|a, b| {
377            use core::cmp::Ordering;
378            for (i, (ka, kb)) in a.0.iter().zip(b.0.iter()).enumerate() {
379                let cmp = value_cmp(ka, kb);
380                let cmp = if descs[i] { cmp.reverse() } else { cmp };
381                if cmp != Ordering::Equal {
382                    return cmp;
383                }
384            }
385            Ordering::Equal
386        });
387        out_rows = tagged.into_iter().map(|(_, o)| o).collect();
388    }
389
390    Ok(AggResult {
391        columns,
392        rows: out_rows,
393    })
394}
395
396/// v7.17.0 — walk the statement again to validate the positional
397/// arity of every aggregate call site. Done after AST collection
398/// rather than inside `collect_aggregates` so the collector stays
399/// infallible; callers in `run()` can do a single early-error
400/// exit before any per-row work.
401fn validate_agg_arities(stmt: &SelectStatement, _specs: &[AggSpec]) -> Result<(), EvalError> {
402    fn walk(e: &Expr) -> Result<(), EvalError> {
403        if let Expr::FunctionCall { name, args } = e {
404            let lower = name.to_ascii_lowercase();
405            let expected: Option<usize> = match lower.as_str() {
406                "count_star" => Some(0),
407                "count" | "sum" | "avg" | "min" | "max" | "array_agg"
408                // v7.17.0 — boolean aggregates also take exactly
409                // one arg. `every` is an alias normalised inside
410                // collect_aggregates / rewrite_expr.
411                | "bool_and" | "bool_or" | "every" => Some(1),
412                "string_agg" => Some(2),
413                _ => None,
414            };
415            if let Some(want) = expected
416                && args.len() != want
417            {
418                return Err(EvalError::TypeMismatch {
419                    detail: alloc::format!("{lower}() takes {want} arg(s), got {}", args.len()),
420                });
421            }
422            for a in args {
423                walk(a)?;
424            }
425        } else if let Expr::Binary { lhs, rhs, .. } = e {
426            walk(lhs)?;
427            walk(rhs)?;
428        } else if let Expr::Unary { expr, .. }
429        | Expr::Cast { expr, .. }
430        | Expr::IsNull { expr, .. } = e
431        {
432            walk(expr)?;
433        }
434        Ok(())
435    }
436    for item in &stmt.items {
437        if let SelectItem::Expr { expr, .. } = item {
438            walk(expr)?;
439        }
440    }
441    for o in &stmt.order_by {
442        walk(&o.expr)?;
443    }
444    if let Some(h) = &stmt.having {
445        walk(h)?;
446    }
447    Ok(())
448}
449
450fn collect_aggregates(e: &Expr, out: &mut Vec<AggSpec>) {
451    match e {
452        Expr::FunctionCall { name, args } => {
453            let lower = name.to_ascii_lowercase();
454            if is_aggregate_name(&lower) {
455                let arg = if lower == "count_star" {
456                    None
457                } else {
458                    args.first().cloned()
459                };
460                // v7.17.0 — second positional arg for
461                // `string_agg(value, separator)`. Everything else
462                // ignores it.
463                let arg2 = if lower == "string_agg" {
464                    args.get(1).cloned()
465                } else {
466                    None
467                };
468                // v7.17.0 — `every` is the SQL-standard alias for
469                // `bool_and`; collapse at collection time so
470                // update_state / finalize need only one arm.
471                let canonical = if lower == "every" {
472                    "bool_and".to_string()
473                } else {
474                    lower
475                };
476                let spec = AggSpec {
477                    name: canonical,
478                    arg: arg.clone(),
479                    arg2: arg2.clone(),
480                };
481                if !out
482                    .iter()
483                    .any(|s| s.name == spec.name && s.arg == spec.arg && s.arg2 == spec.arg2)
484                {
485                    out.push(spec);
486                }
487                // Don't recurse into the arg — nested aggregates are
488                // illegal in standard SQL.
489            } else {
490                for a in args {
491                    collect_aggregates(a, out);
492                }
493            }
494        }
495        Expr::Binary { lhs, rhs, .. } => {
496            collect_aggregates(lhs, out);
497            collect_aggregates(rhs, out);
498        }
499        Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
500            collect_aggregates(expr, out);
501        }
502        Expr::Like { expr, pattern, .. } => {
503            collect_aggregates(expr, out);
504            collect_aggregates(pattern, out);
505        }
506        Expr::Extract { source, .. } => collect_aggregates(source, out),
507        // v4.10 subquery + v4.12 window / Literal / Column —
508        // non-recursing leaves for the aggregate collector.
509        Expr::ScalarSubquery(_)
510        | Expr::Exists { .. }
511        | Expr::InSubquery { .. }
512        | Expr::WindowFunction { .. }
513        | Expr::Literal(_)
514        | Expr::Placeholder(_)
515        | Expr::Column(_) => {}
516        // v7.10.10 — recurse into array constructor children +
517        // subscript / ANY/ALL operands.
518        Expr::Array(items) => {
519            for elem in items {
520                collect_aggregates(elem, out);
521            }
522        }
523        Expr::ArraySubscript { target, index } => {
524            collect_aggregates(target, out);
525            collect_aggregates(index, out);
526        }
527        Expr::AnyAll { expr, array, .. } => {
528            collect_aggregates(expr, out);
529            collect_aggregates(array, out);
530        }
531        Expr::Case {
532            operand,
533            branches,
534            else_branch,
535        } => {
536            if let Some(o) = operand {
537                collect_aggregates(o, out);
538            }
539            for (w, t) in branches {
540                collect_aggregates(w, out);
541                collect_aggregates(t, out);
542            }
543            if let Some(e) = else_branch {
544                collect_aggregates(e, out);
545            }
546        }
547    }
548}
549
550fn update_state(
551    st: &mut AggState,
552    name: &str,
553    v: &Value,
554    arg2: Option<&Value>,
555) -> Result<(), EvalError> {
556    let is_null = matches!(v, Value::Null);
557    match name {
558        "count_star" => st.count += 1,
559        "count" => {
560            if !is_null {
561                st.count += 1;
562            }
563        }
564        "sum" | "avg" => {
565            if is_null {
566                return Ok(());
567            }
568            st.count += 1;
569            match v {
570                Value::Int(n) => st.sum_int += i64::from(*n),
571                Value::BigInt(n) => st.sum_int += *n,
572                Value::Float(x) => {
573                    st.use_float = true;
574                    st.sum_float += *x;
575                }
576                other => {
577                    return Err(EvalError::TypeMismatch {
578                        detail: format!("sum/avg need numeric, got {:?}", other.data_type()),
579                    });
580                }
581            }
582        }
583        "min" => {
584            if is_null {
585                return Ok(());
586            }
587            match &st.extreme {
588                None => st.extreme = Some(v.clone()),
589                Some(cur) => {
590                    if value_cmp(v, cur) == core::cmp::Ordering::Less {
591                        st.extreme = Some(v.clone());
592                    }
593                }
594            }
595        }
596        "max" => {
597            if is_null {
598                return Ok(());
599            }
600            match &st.extreme {
601                None => st.extreme = Some(v.clone()),
602                Some(cur) => {
603                    if value_cmp(v, cur) == core::cmp::Ordering::Greater {
604                        st.extreme = Some(v.clone());
605                    }
606                }
607            }
608        }
609        // v7.17.0 — string_agg(value, separator). NULL value is
610        // skipped (PG aggregate-skip-null). Separator captured
611        // from the latest row that flows through; matches PG's
612        // semantics of evaluating the separator per row but using
613        // the last value at finalize time (in practice it's
614        // constant). count is bumped so we can distinguish "empty
615        // group → NULL" from "all-NULL group → NULL".
616        "string_agg" => {
617            if let Some(sep) = arg2
618                && let Value::Text(s) = sep
619            {
620                st.separator = Some(s.clone());
621            }
622            if is_null {
623                return Ok(());
624            }
625            if let Value::Text(s) = v {
626                st.items.push(Value::Text(s.clone()));
627                st.count += 1;
628            } else {
629                return Err(EvalError::TypeMismatch {
630                    detail: format!("string_agg requires text value, got {:?}", v.data_type()),
631                });
632            }
633        }
634        // v7.17.0 — array_agg(value). Unlike string_agg, NULL
635        // elements are KEPT in the array (PG behaviour); the
636        // result is NULL only when ZERO rows fed in. Element type
637        // is locked from the first row's value type; subsequent
638        // rows must match (PG also rejects mixed-type array_agg).
639        "array_agg" => {
640            st.items.push(v.clone());
641            st.count += 1;
642        }
643        // v7.17.0 — bool_and(p): TRUE iff every non-NULL input is
644        // TRUE. NULL skipped; running accumulator stays at TRUE
645        // until the first non-NULL FALSE.
646        "bool_and" => {
647            if is_null {
648                return Ok(());
649            }
650            let b = match v {
651                Value::Bool(b) => *b,
652                other => {
653                    return Err(EvalError::TypeMismatch {
654                        detail: format!("bool_and requires bool, got {:?}", other.data_type()),
655                    });
656                }
657            };
658            st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc && b));
659        }
660        // v7.17.0 — bool_or(p): TRUE iff any non-NULL input is
661        // TRUE. NULL skipped.
662        "bool_or" => {
663            if is_null {
664                return Ok(());
665            }
666            let b = match v {
667                Value::Bool(b) => *b,
668                other => {
669                    return Err(EvalError::TypeMismatch {
670                        detail: format!("bool_or requires bool, got {:?}", other.data_type()),
671                    });
672                }
673            };
674            st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc || b));
675        }
676        _ => unreachable!("non-aggregate {name} in update_state"),
677    }
678    Ok(())
679}
680
681#[allow(clippy::cast_precision_loss)]
682fn finalize(name: &str, st: &AggState) -> Value {
683    match name {
684        "count" | "count_star" => Value::BigInt(st.count),
685        "sum" => {
686            if st.count == 0 {
687                Value::Null
688            } else if st.use_float {
689                Value::Float(st.sum_float + (st.sum_int as f64))
690            } else {
691                Value::BigInt(st.sum_int)
692            }
693        }
694        "avg" => {
695            if st.count == 0 {
696                Value::Null
697            } else {
698                let total = if st.use_float {
699                    st.sum_float + (st.sum_int as f64)
700                } else {
701                    st.sum_int as f64
702                };
703                Value::Float(total / (st.count as f64))
704            }
705        }
706        "min" | "max" => st.extreme.clone().unwrap_or(Value::Null),
707        // v7.17.0 — string_agg: join all collected text items with
708        // the captured separator. Empty / all-NULL group → NULL
709        // (PG semantics).
710        "string_agg" => {
711            if st.items.is_empty() {
712                return Value::Null;
713            }
714            let sep = st.separator.clone().unwrap_or_default();
715            let mut out = String::new();
716            for (i, item) in st.items.iter().enumerate() {
717                if i > 0 {
718                    out.push_str(&sep);
719                }
720                if let Value::Text(s) = item {
721                    out.push_str(s);
722                }
723            }
724            Value::Text(out)
725        }
726        // v7.17.0 — array_agg: collect into a typed array. NULL
727        // elements are preserved per PG. Result type is decided
728        // by the first non-NULL element seen (or Text fallback
729        // when the whole group is NULL — PG would surface the
730        // declared input type, but SPG hasn't yet wired the
731        // aggregate's static input-type from `describe`).
732        "array_agg" => {
733            if st.items.is_empty() {
734                return Value::Null;
735            }
736            let probe = st.items.iter().find(|v| !v.is_null());
737            match probe.and_then(spg_storage::Value::data_type) {
738                Some(DataType::Int) | Some(DataType::SmallInt) => {
739                    let items: Vec<Option<i32>> = st
740                        .items
741                        .iter()
742                        .map(|v| match v {
743                            Value::Int(n) => Some(*n),
744                            Value::SmallInt(n) => Some(i32::from(*n)),
745                            _ => None,
746                        })
747                        .collect();
748                    Value::IntArray(items)
749                }
750                Some(DataType::BigInt) => {
751                    let items: Vec<Option<i64>> = st
752                        .items
753                        .iter()
754                        .map(|v| match v {
755                            Value::BigInt(n) => Some(*n),
756                            _ => None,
757                        })
758                        .collect();
759                    Value::BigIntArray(items)
760                }
761                _ => {
762                    let items: Vec<Option<String>> = st
763                        .items
764                        .iter()
765                        .map(|v| match v {
766                            Value::Text(s) => Some(s.clone()),
767                            Value::Null => None,
768                            other => Some(format!("{other:?}")),
769                        })
770                        .collect();
771                    Value::TextArray(items)
772                }
773            }
774        }
775        // v7.17.0 — bool_and / bool_or finalize: lazy-init pattern
776        // means `None` is exactly "empty group or all-NULL", which
777        // PG surfaces as SQL NULL.
778        "bool_and" | "bool_or" => st.bool_acc.map_or(Value::Null, Value::Bool),
779        _ => unreachable!(),
780    }
781}
782
783fn infer_agg_type(spec: &AggSpec) -> DataType {
784    match spec.name.as_str() {
785        // count/count_star are exact integer counts; sum widens to BigInt
786        // and reports as such even for Float input (the value column is
787        // nullable so the wire layer surfaces the Float at runtime).
788        "count" | "count_star" | "sum" => DataType::BigInt,
789        "avg" => DataType::Float,
790        // v7.17.0 — string_agg always returns TEXT.
791        "string_agg" => DataType::Text,
792        // v7.17.0 — array_agg's declared output type can't be
793        // known without inspecting the argument's expression
794        // shape. Default to TextArray; finalize widens to
795        // IntArray / BigIntArray when the actual elements are
796        // numeric. Downstream column metadata reports TextArray
797        // which is the lowest common denominator.
798        "array_agg" => DataType::TextArray,
799        // v7.17.0 — boolean aggregates always return BOOL (nullable
800        // — empty / all-NULL group → NULL).
801        "bool_and" | "bool_or" => DataType::Bool,
802        // min/max: we don't know the input type without probing — default
803        // to Text and let downstream rendering coerce.
804        _ => DataType::Text,
805    }
806}
807
808fn agg_or_group_type(e: &Expr, synth: &[ColumnSchema]) -> DataType {
809    if let Expr::Column(c) = e
810        && let Some(s) = synth.iter().find(|s| s.name == c.name)
811    {
812        return s.ty;
813    }
814    // Compound expression — fall back to Text (matches build_projection
815    // behaviour for non-column expressions in the non-aggregate path).
816    DataType::Text
817}
818
819fn rewrite_expr(e: &Expr, group_exprs: &[Expr], aggs: &[AggSpec]) -> Expr {
820    // Match aggregate FunctionCalls first — they sit outside group_by.
821    if let Expr::FunctionCall { name, args } = e {
822        let lower = name.to_ascii_lowercase();
823        if is_aggregate_name(&lower) {
824            let arg = if lower == "count_star" {
825                None
826            } else {
827                args.first().cloned()
828            };
829            // v7.17.0 — match the spec we registered for
830            // string_agg(value, separator) on the full pair.
831            let arg2 = if lower == "string_agg" {
832                args.get(1).cloned()
833            } else {
834                None
835            };
836            // v7.17.0 — `every` collapses into `bool_and` at
837            // collection; mirror that here so the rewrite finds
838            // the matching synth column.
839            let canonical: &str = if lower == "every" {
840                "bool_and"
841            } else {
842                lower.as_str()
843            };
844            for (i, spec) in aggs.iter().enumerate() {
845                if spec.name == canonical && spec.arg == arg && spec.arg2 == arg2 {
846                    return Expr::Column(spg_sql::ast::ColumnName {
847                        qualifier: None,
848                        name: format!("__agg_{i}"),
849                    });
850                }
851            }
852        }
853    }
854    // Match a group_by expression by AST equality.
855    for (i, g) in group_exprs.iter().enumerate() {
856        if g == e {
857            return Expr::Column(spg_sql::ast::ColumnName {
858                qualifier: None,
859                name: format!("__grp_{i}"),
860            });
861        }
862    }
863    // Recurse into children.
864    match e {
865        Expr::Binary { lhs, op, rhs } => Expr::Binary {
866            lhs: Box::new(rewrite_expr(lhs, group_exprs, aggs)),
867            op: *op,
868            rhs: Box::new(rewrite_expr(rhs, group_exprs, aggs)),
869        },
870        Expr::Unary { op, expr } => Expr::Unary {
871            op: *op,
872            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
873        },
874        Expr::Cast { expr, target } => Expr::Cast {
875            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
876            target: *target,
877        },
878        Expr::IsNull { expr, negated } => Expr::IsNull {
879            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
880            negated: *negated,
881        },
882        Expr::FunctionCall { name, args } => Expr::FunctionCall {
883            name: name.clone(),
884            args: args
885                .iter()
886                .map(|a| rewrite_expr(a, group_exprs, aggs))
887                .collect(),
888        },
889        Expr::Like {
890            expr,
891            pattern,
892            negated,
893        } => Expr::Like {
894            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
895            pattern: Box::new(rewrite_expr(pattern, group_exprs, aggs)),
896            negated: *negated,
897        },
898        Expr::Extract { field, source } => Expr::Extract {
899            field: *field,
900            source: Box::new(rewrite_expr(source, group_exprs, aggs)),
901        },
902        // v4.10 subquery + v4.12 window / Literal / Column —
903        // clone-pass (these don't participate in aggregate rewrite).
904        Expr::ScalarSubquery(_)
905        | Expr::Exists { .. }
906        | Expr::InSubquery { .. }
907        | Expr::WindowFunction { .. }
908        | Expr::Literal(_)
909        | Expr::Placeholder(_)
910        | Expr::Column(_) => e.clone(),
911        // v7.10.10 — recurse children for array nodes.
912        Expr::Array(items) => Expr::Array(
913            items
914                .iter()
915                .map(|elem| rewrite_expr(elem, group_exprs, aggs))
916                .collect(),
917        ),
918        Expr::ArraySubscript { target, index } => Expr::ArraySubscript {
919            target: Box::new(rewrite_expr(target, group_exprs, aggs)),
920            index: Box::new(rewrite_expr(index, group_exprs, aggs)),
921        },
922        Expr::AnyAll {
923            expr,
924            op,
925            array,
926            is_any,
927        } => Expr::AnyAll {
928            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
929            op: *op,
930            array: Box::new(rewrite_expr(array, group_exprs, aggs)),
931            is_any: *is_any,
932        },
933        Expr::Case {
934            operand,
935            branches,
936            else_branch,
937        } => Expr::Case {
938            operand: operand
939                .as_deref()
940                .map(|o| Box::new(rewrite_expr(o, group_exprs, aggs))),
941            branches: branches
942                .iter()
943                .map(|(w, t)| {
944                    (
945                        rewrite_expr(w, group_exprs, aggs),
946                        rewrite_expr(t, group_exprs, aggs),
947                    )
948                })
949                .collect(),
950            else_branch: else_branch
951                .as_deref()
952                .map(|e| Box::new(rewrite_expr(e, group_exprs, aggs))),
953        },
954    }
955}
956
957/// Canonical string key for a tuple of group values. Used as map key.
958fn encode_key(vals: &[Value]) -> String {
959    let mut out = String::new();
960    for v in vals {
961        match v {
962            Value::Null => out.push_str("N|"),
963            Value::SmallInt(n) => {
964                out.push('s');
965                out.push_str(&n.to_string());
966                out.push('|');
967            }
968            Value::Int(n) => {
969                out.push('I');
970                out.push_str(&n.to_string());
971                out.push('|');
972            }
973            Value::BigInt(n) => {
974                out.push('B');
975                out.push_str(&n.to_string());
976                out.push('|');
977            }
978            Value::Float(x) => {
979                out.push('F');
980                out.push_str(&x.to_string());
981                out.push('|');
982            }
983            Value::Bool(b) => {
984                out.push(if *b { 'T' } else { 'f' });
985                out.push('|');
986            }
987            Value::Text(s) => {
988                out.push('S');
989                out.push_str(s);
990                out.push('|');
991            }
992            Value::Vector(v) => {
993                out.push('V');
994                for x in v {
995                    out.push_str(&x.to_string());
996                    out.push(',');
997                }
998                out.push('|');
999            }
1000            // v6.0.1: GROUP BY on a `VECTOR(N) USING SQ8` column.
1001            // Two cells with byte-identical `(min, max, bytes)`
1002            // share the same group; equivalence is byte-equality
1003            // (same as f32 grouping today — neither path tries to
1004            // normalise nan/-0).
1005            Value::Sq8Vector(q) => {
1006                out.push('Q');
1007                out.push_str(&q.min.to_string());
1008                out.push('@');
1009                out.push_str(&q.max.to_string());
1010                out.push(':');
1011                for b in &q.bytes {
1012                    out.push_str(&b.to_string());
1013                    out.push(',');
1014                }
1015                out.push('|');
1016            }
1017            // v6.0.3: GROUP BY on a `VECTOR(N) USING HALF` column.
1018            // Byte-equality over the raw u16 bits; matches the SQ8
1019            // path's byte-key model.
1020            Value::HalfVector(h) => {
1021                out.push('H');
1022                for b in &h.bytes {
1023                    out.push_str(&b.to_string());
1024                    out.push(',');
1025                }
1026                out.push('|');
1027            }
1028            Value::Numeric { scaled, scale } => {
1029                out.push('D');
1030                out.push_str(&scaled.to_string());
1031                out.push('@');
1032                out.push_str(&scale.to_string());
1033                out.push('|');
1034            }
1035            Value::Date(d) => {
1036                out.push('d');
1037                out.push_str(&d.to_string());
1038                out.push('|');
1039            }
1040            Value::Timestamp(t) => {
1041                out.push('t');
1042                out.push_str(&t.to_string());
1043                out.push('|');
1044            }
1045            Value::Interval { months, micros } => {
1046                out.push('i');
1047                out.push_str(&months.to_string());
1048                out.push('m');
1049                out.push_str(&micros.to_string());
1050                out.push('|');
1051            }
1052            Value::Json(s) => {
1053                out.push('j');
1054                out.push_str(s);
1055                out.push('|');
1056            }
1057            // v7.5.0 — Value is #[non_exhaustive] for downstream
1058            // forward-compat. Any future variant lacking explicit
1059            // handling here will share a debug-derived group key,
1060            // which is observably wrong but won't crash.
1061            _ => {
1062                out.push('?');
1063                out.push_str(&format!("{v:?}"));
1064                out.push('|');
1065            }
1066        }
1067    }
1068    out
1069}
1070
1071#[allow(clippy::cast_precision_loss)]
1072fn value_cmp(a: &Value, b: &Value) -> core::cmp::Ordering {
1073    use core::cmp::Ordering::Equal;
1074    match (a, b) {
1075        (Value::Null, Value::Null) => Equal,
1076        (Value::Null, _) => core::cmp::Ordering::Greater, // NULLs last
1077        (_, Value::Null) => core::cmp::Ordering::Less,
1078        (Value::Int(x), Value::Int(y)) => x.cmp(y),
1079        (Value::BigInt(x), Value::BigInt(y)) => x.cmp(y),
1080        (Value::Int(x), Value::BigInt(y)) => i64::from(*x).cmp(y),
1081        (Value::BigInt(x), Value::Int(y)) => x.cmp(&i64::from(*y)),
1082        (Value::Float(x), Value::Float(y)) => x.partial_cmp(y).unwrap_or(Equal),
1083        (Value::Int(x), Value::Float(y)) => f64::from(*x).partial_cmp(y).unwrap_or(Equal),
1084        (Value::Float(x), Value::Int(y)) => x.partial_cmp(&f64::from(*y)).unwrap_or(Equal),
1085        (Value::BigInt(x), Value::Float(y)) => (*x as f64).partial_cmp(y).unwrap_or(Equal),
1086        (Value::Float(x), Value::BigInt(y)) => x.partial_cmp(&(*y as f64)).unwrap_or(Equal),
1087        (Value::Text(x), Value::Text(y)) => x.cmp(y),
1088        (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
1089        _ => Equal,
1090    }
1091}