Skip to main content

sochdb_query/sql/
aggregate.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! # SQL Aggregation Executor
19//!
20//! Hash-aggregation operator for `GROUP BY` and aggregate functions.
21//!
22//! Supported aggregates: `COUNT(*)`, `COUNT(col)`, `COUNT(DISTINCT col)`,
23//! `SUM`, `AVG`, `MIN`, `MAX`, `MEDIAN`, `STDDEV` (sample, n-1, matching
24//! R's `sd()` and DuckDB's `stddev`).
25//!
26//! ## Pipeline
27//!
28//! ```text
29//! input rows (post-WHERE)
30//!   └─> group keys evaluated per row ──> hash table of group states
31//!         └─> accumulators updated per row
32//!               └─> finalize: one synthesized row per group
33//!                     └─> HAVING filter ─> ORDER BY ─> OFFSET/LIMIT ─> projection
34//! ```
35//!
36//! Semantics notes:
37//! - NULL inputs are skipped by all aggregates except `COUNT(*)` (SQL standard).
38//! - An ungrouped aggregate over zero rows yields exactly one row
39//!   (`COUNT` = 0, other aggregates NULL); a grouped aggregate over zero
40//!   rows yields zero rows.
41//! - Non-aggregate SELECT columns that are not in `GROUP BY` resolve to the
42//!   first value seen in the group (lenient mode, like SQLite / MySQL with
43//!   `ONLY_FULL_GROUP_BY` disabled).
44
45use super::ast::*;
46use super::bridge::ExecutionResult;
47use super::error::SqlResult;
48use rayon::prelude::*;
49use sochdb_core::SochValue;
50use std::collections::{HashMap, HashSet};
51
/// Input row count above which grouped accumulation is run on the rayon
/// thread pool instead of on the calling thread.
const PARALLEL_THRESHOLD: usize = 100_000;
54
55// ============================================================================
56// Aggregate function identification
57// ============================================================================
58
/// The aggregate functions this executor knows how to compute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggFn {
    Count,
    Sum,
    Avg,
    Min,
    Max,
    Median,
    Stddev,
}

impl AggFn {
    /// Maps a SQL function name onto its aggregate, ignoring ASCII case.
    ///
    /// Besides the canonical names, `MEAN` is accepted for `AVG`, and
    /// `STDDEV_SAMP`, `STDEV` and `SD` are accepted for `STDDEV`.
    /// Returns `None` for any name that is not an aggregate.
    pub fn from_name(name: &str) -> Option<Self> {
        // Every spelling we accept, paired with the function it denotes.
        const ALIASES: [(&str, AggFn); 11] = [
            ("COUNT", AggFn::Count),
            ("SUM", AggFn::Sum),
            ("AVG", AggFn::Avg),
            ("MEAN", AggFn::Avg),
            ("MIN", AggFn::Min),
            ("MAX", AggFn::Max),
            ("MEDIAN", AggFn::Median),
            ("STDDEV", AggFn::Stddev),
            ("STDDEV_SAMP", AggFn::Stddev),
            ("STDEV", AggFn::Stddev),
            ("SD", AggFn::Stddev),
        ];
        ALIASES
            .iter()
            .find(|(alias, _)| name.eq_ignore_ascii_case(alias))
            .map(|&(_, func)| func)
    }
}
86
/// One aggregate call discovered in the query, e.g. `sum(v1)`.
///
/// Built by `collect_from_expr`; at most one spec exists per canonical key.
#[derive(Debug, Clone)]
struct AggSpec {
    /// Canonical key, e.g. `"sum(v1)"` — used to bind HAVING / ORDER BY
    /// references back to the computed value.
    key: String,
    /// The aggregate function being called.
    func: AggFn,
    /// Argument expression (`None` for `COUNT(*)`). Only the first argument
    /// of the call is recorded.
    arg: Option<Expr>,
    /// Whether the call carried `DISTINCT`. Note that `Acc::new` only
    /// consults this flag for `COUNT`.
    distinct: bool,
}
98
99/// Returns true if the SELECT needs the aggregation operator.
100pub fn is_aggregate_query(select: &SelectStmt) -> bool {
101    if !select.group_by.is_empty() {
102        return true;
103    }
104    select
105        .columns
106        .iter()
107        .any(|item| matches!(item, SelectItem::Expr { expr, .. } if contains_aggregate(expr)))
108}
109
110/// Recursively check whether an expression contains an aggregate call.
111fn contains_aggregate(expr: &Expr) -> bool {
112    match expr {
113        Expr::Function(f) => {
114            AggFn::from_name(f.name.name()).is_some()
115                || f.args.iter().any(contains_aggregate)
116        }
117        Expr::BinaryOp { left, right, .. } => {
118            contains_aggregate(left) || contains_aggregate(right)
119        }
120        Expr::UnaryOp { expr, .. } => contains_aggregate(expr),
121        Expr::IsNull { expr, .. } => contains_aggregate(expr),
122        Expr::Case {
123            operand,
124            conditions,
125            else_result,
126        } => {
127            operand.as_deref().map(contains_aggregate).unwrap_or(false)
128                || conditions
129                    .iter()
130                    .any(|(w, t)| contains_aggregate(w) || contains_aggregate(t))
131                || else_result
132                    .as_deref()
133                    .map(contains_aggregate)
134                    .unwrap_or(false)
135        }
136        _ => false,
137    }
138}
139
140/// Collect all distinct aggregate calls from SELECT, HAVING and ORDER BY.
141fn collect_agg_specs(select: &SelectStmt) -> Vec<AggSpec> {
142    let mut specs: Vec<AggSpec> = Vec::new();
143    let mut seen: HashSet<String> = HashSet::new();
144
145    let walk = |expr: &Expr, specs: &mut Vec<AggSpec>, seen: &mut HashSet<String>| {
146        collect_from_expr(expr, specs, seen);
147    };
148
149    for item in &select.columns {
150        if let SelectItem::Expr { expr, .. } = item {
151            walk(expr, &mut specs, &mut seen);
152        }
153    }
154    if let Some(h) = &select.having {
155        walk(h, &mut specs, &mut seen);
156    }
157    for ob in &select.order_by {
158        walk(&ob.expr, &mut specs, &mut seen);
159    }
160    specs
161}
162
163fn collect_from_expr(expr: &Expr, specs: &mut Vec<AggSpec>, seen: &mut HashSet<String>) {
164    match expr {
165        Expr::Function(f) => {
166            if let Some(func) = AggFn::from_name(f.name.name()) {
167                let arg = f.args.first().cloned();
168                let is_star =
169                    matches!(arg.as_ref(), Some(Expr::Column(c)) if c.column == "*");
170                let arg = if is_star { None } else { arg };
171                let key = render_agg_key(func, arg.as_ref(), f.distinct);
172                if seen.insert(key.clone()) {
173                    specs.push(AggSpec {
174                        key,
175                        func,
176                        arg,
177                        distinct: f.distinct,
178                    });
179                }
180            } else {
181                for a in &f.args {
182                    collect_from_expr(a, specs, seen);
183                }
184            }
185        }
186        Expr::BinaryOp { left, right, .. } => {
187            collect_from_expr(left, specs, seen);
188            collect_from_expr(right, specs, seen);
189        }
190        Expr::UnaryOp { expr, .. } => collect_from_expr(expr, specs, seen),
191        Expr::IsNull { expr, .. } => collect_from_expr(expr, specs, seen),
192        Expr::Case {
193            operand,
194            conditions,
195            else_result,
196        } => {
197            if let Some(op) = operand {
198                collect_from_expr(op, specs, seen);
199            }
200            for (w, t) in conditions {
201                collect_from_expr(w, specs, seen);
202                collect_from_expr(t, specs, seen);
203            }
204            if let Some(e) = else_result {
205                collect_from_expr(e, specs, seen);
206            }
207        }
208        _ => {}
209    }
210}
211
212/// Canonical name for an aggregate call: `sum(v1)`, `count(*)`,
213/// `count(distinct id)`. Lowercased so lookups are case-insensitive.
214fn render_agg_key(func: AggFn, arg: Option<&Expr>, distinct: bool) -> String {
215    let fname = match func {
216        AggFn::Count => "count",
217        AggFn::Sum => "sum",
218        AggFn::Avg => "avg",
219        AggFn::Min => "min",
220        AggFn::Max => "max",
221        AggFn::Median => "median",
222        AggFn::Stddev => "stddev",
223    };
224    let arg_s = match arg {
225        None => "*".to_string(),
226        Some(e) => render_expr_name(e),
227    };
228    if distinct {
229        format!("{}(distinct {})", fname, arg_s)
230    } else {
231        format!("{}({})", fname, arg_s)
232    }
233}
234
235/// Human-readable name for an expression, used for output column naming
236/// and canonical aggregate keys.
237pub fn render_expr_name(expr: &Expr) -> String {
238    match expr {
239        Expr::Column(c) => {
240            if let Some(t) = &c.table {
241                format!("{}.{}", t, c.column)
242            } else {
243                c.column.clone()
244            }
245        }
246        Expr::Literal(Literal::Integer(n)) => n.to_string(),
247        Expr::Literal(Literal::Float(f)) => f.to_string(),
248        Expr::Literal(Literal::String(s)) => format!("'{}'", s),
249        Expr::Literal(Literal::Boolean(b)) => b.to_string(),
250        Expr::Literal(Literal::Null) => "null".to_string(),
251        Expr::Function(f) => {
252            if let Some(func) = AggFn::from_name(f.name.name()) {
253                let arg = f.args.first();
254                let is_star =
255                    matches!(arg, Some(Expr::Column(c)) if c.column == "*");
256                render_agg_key(func, if is_star { None } else { arg }, f.distinct)
257            } else {
258                let args: Vec<String> = f.args.iter().map(render_expr_name).collect();
259                format!("{}({})", f.name.name().to_lowercase(), args.join(", "))
260            }
261        }
262        Expr::BinaryOp { left, op, right } => format!(
263            "{} {} {}",
264            render_expr_name(left),
265            binary_op_symbol(op),
266            render_expr_name(right)
267        ),
268        Expr::UnaryOp { op, expr } => match op {
269            UnaryOperator::Minus => format!("-{}", render_expr_name(expr)),
270            UnaryOperator::Plus => render_expr_name(expr),
271            UnaryOperator::Not => format!("not {}", render_expr_name(expr)),
272            UnaryOperator::BitNot => format!("~{}", render_expr_name(expr)),
273        },
274        _ => "expr".to_string(),
275    }
276}
277
278fn binary_op_symbol(op: &BinaryOperator) -> &'static str {
279    match op {
280        BinaryOperator::Plus => "+",
281        BinaryOperator::Minus => "-",
282        BinaryOperator::Multiply => "*",
283        BinaryOperator::Divide => "/",
284        BinaryOperator::Modulo => "%",
285        BinaryOperator::Eq => "=",
286        BinaryOperator::Ne => "<>",
287        BinaryOperator::Lt => "<",
288        BinaryOperator::Le => "<=",
289        BinaryOperator::Gt => ">",
290        BinaryOperator::Ge => ">=",
291        BinaryOperator::And => "and",
292        BinaryOperator::Or => "or",
293        _ => "?",
294    }
295}
296
297// ============================================================================
298// Scalar expression evaluation (over a materialized row)
299// ============================================================================
300
301/// Evaluate a scalar expression against a row map.
302///
303/// `agg_values`, when provided, resolves aggregate function calls by their
304/// canonical key — used for HAVING / ORDER BY / projection over finalized
305/// group rows.
306fn eval_scalar(
307    expr: &Expr,
308    row: &HashMap<String, SochValue>,
309    params: &[SochValue],
310) -> SochValue {
311    match expr {
312        Expr::Column(c) => {
313            if let Some(t) = &c.table {
314                let qualified = format!("{}.{}", t, c.column);
315                if let Some(v) = row.get(&qualified) {
316                    return v.clone();
317                }
318            }
319            row.get(&c.column).cloned().unwrap_or(SochValue::Null)
320        }
321        Expr::Literal(lit) => literal_to_value(lit),
322        Expr::Placeholder(idx) => params
323            .get((*idx as usize).saturating_sub(1))
324            .cloned()
325            .unwrap_or(SochValue::Null),
326        Expr::Function(f) => {
327            // Aggregate results are pre-bound into the row map under their
328            // canonical key by `finalize_groups`.
329            let key = render_expr_name(&Expr::Function(f.clone()));
330            row.get(&key).cloned().unwrap_or(SochValue::Null)
331        }
332        Expr::BinaryOp { left, op, right } => {
333            let l = eval_scalar(left, row, params);
334            let r = eval_scalar(right, row, params);
335            eval_binary(&l, op, &r)
336        }
337        Expr::UnaryOp { op, expr } => {
338            let v = eval_scalar(expr, row, params);
339            match op {
340                UnaryOperator::Minus => match v {
341                    SochValue::Int(i) => SochValue::Int(-i),
342                    SochValue::Float(f) => SochValue::Float(-f),
343                    _ => SochValue::Null,
344                },
345                UnaryOperator::Plus => v,
346                UnaryOperator::Not => match v {
347                    SochValue::Bool(b) => SochValue::Bool(!b),
348                    _ => SochValue::Null,
349                },
350                UnaryOperator::BitNot => match v {
351                    SochValue::Int(i) => SochValue::Int(!i),
352                    _ => SochValue::Null,
353                },
354            }
355        }
356        Expr::IsNull { expr, negated } => {
357            let v = eval_scalar(expr, row, params);
358            let is_null = v.is_null();
359            SochValue::Bool(if *negated { !is_null } else { is_null })
360        }
361        _ => SochValue::Null,
362    }
363}
364
365fn literal_to_value(lit: &Literal) -> SochValue {
366    match lit {
367        Literal::Integer(i) => SochValue::Int(*i),
368        Literal::Float(f) => SochValue::Float(*f),
369        Literal::String(s) => SochValue::Text(s.clone()),
370        Literal::Boolean(b) => SochValue::Bool(*b),
371        Literal::Null => SochValue::Null,
372        _ => SochValue::Null,
373    }
374}
375
376fn numeric(v: &SochValue) -> Option<f64> {
377    match v {
378        SochValue::Int(i) => Some(*i as f64),
379        SochValue::UInt(u) => Some(*u as f64),
380        SochValue::Float(f) => Some(*f),
381        SochValue::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
382        _ => None,
383    }
384}
385
386fn eval_binary(l: &SochValue, op: &BinaryOperator, r: &SochValue) -> SochValue {
387    use BinaryOperator::*;
388    match op {
389        Plus | Minus | Multiply | Divide | Modulo => {
390            // Integer arithmetic when both sides are ints (except division).
391            if let (SochValue::Int(a), SochValue::Int(b)) = (l, r) {
392                return match op {
393                    Plus => SochValue::Int(a.wrapping_add(*b)),
394                    Minus => SochValue::Int(a.wrapping_sub(*b)),
395                    Multiply => SochValue::Int(a.wrapping_mul(*b)),
396                    Divide => {
397                        if *b == 0 {
398                            SochValue::Null
399                        } else {
400                            SochValue::Float(*a as f64 / *b as f64)
401                        }
402                    }
403                    Modulo => {
404                        if *b == 0 {
405                            SochValue::Null
406                        } else {
407                            SochValue::Int(a % b)
408                        }
409                    }
410                    _ => unreachable!(),
411                };
412            }
413            let (a, b) = match (numeric(l), numeric(r)) {
414                (Some(a), Some(b)) => (a, b),
415                _ => return SochValue::Null,
416            };
417            match op {
418                Plus => SochValue::Float(a + b),
419                Minus => SochValue::Float(a - b),
420                Multiply => SochValue::Float(a * b),
421                Divide => {
422                    if b == 0.0 {
423                        SochValue::Null
424                    } else {
425                        SochValue::Float(a / b)
426                    }
427                }
428                Modulo => {
429                    if b == 0.0 {
430                        SochValue::Null
431                    } else {
432                        SochValue::Float(a % b)
433                    }
434                }
435                _ => unreachable!(),
436            }
437        }
438        Eq | Ne | Lt | Le | Gt | Ge => {
439            if l.is_null() || r.is_null() {
440                return SochValue::Null;
441            }
442            let ord = compare_values(l, r);
443            let b = match op {
444                Eq => ord == std::cmp::Ordering::Equal,
445                Ne => ord != std::cmp::Ordering::Equal,
446                Lt => ord == std::cmp::Ordering::Less,
447                Le => ord != std::cmp::Ordering::Greater,
448                Gt => ord == std::cmp::Ordering::Greater,
449                Ge => ord != std::cmp::Ordering::Less,
450                _ => unreachable!(),
451            };
452            SochValue::Bool(b)
453        }
454        And => match (as_bool(l), as_bool(r)) {
455            (Some(a), Some(b)) => SochValue::Bool(a && b),
456            _ => SochValue::Null,
457        },
458        Or => match (as_bool(l), as_bool(r)) {
459            (Some(a), Some(b)) => SochValue::Bool(a || b),
460            _ => SochValue::Null,
461        },
462        _ => SochValue::Null,
463    }
464}
465
466fn as_bool(v: &SochValue) -> Option<bool> {
467    match v {
468        SochValue::Bool(b) => Some(*b),
469        SochValue::Int(i) => Some(*i != 0),
470        SochValue::Null => None,
471        _ => None,
472    }
473}
474
475/// Total ordering across SochValue for grouping/sorting.
476pub fn compare_values(a: &SochValue, b: &SochValue) -> std::cmp::Ordering {
477    use std::cmp::Ordering;
478    match (numeric(a), numeric(b)) {
479        (Some(x), Some(y)) => return x.partial_cmp(&y).unwrap_or(Ordering::Equal),
480        _ => {}
481    }
482    match (a, b) {
483        (SochValue::Text(x), SochValue::Text(y)) => x.cmp(y),
484        (SochValue::Null, SochValue::Null) => Ordering::Equal,
485        (SochValue::Null, _) => Ordering::Less,
486        (_, SochValue::Null) => Ordering::Greater,
487        _ => Ordering::Equal,
488    }
489}
490
491/// Canonical hash representation of a group-key value.
492/// Normalizes Int/UInt/Float-of-integral so `1`, `1u`, `1.0` group together.
493fn key_repr(v: &SochValue) -> String {
494    match v {
495        SochValue::Null => "\u{0}N".to_string(),
496        SochValue::Int(i) => format!("i{}", i),
497        SochValue::UInt(u) => format!("i{}", u),
498        SochValue::Float(f) => {
499            if f.fract() == 0.0 && f.abs() < 9.0e15 {
500                format!("i{}", *f as i64)
501            } else {
502                format!("f{}", f)
503            }
504        }
505        SochValue::Text(s) => format!("s{}", s),
506        SochValue::Bool(b) => format!("b{}", b),
507        other => format!("{:?}", other),
508    }
509}
510
511// ============================================================================
512// Accumulators
513// ============================================================================
514
/// Running state of one aggregate within one group.
///
/// Created by `Acc::new`, fed one value per row by `update`, combined across
/// partial results by `merge`, and turned into the result by `finalize`.
#[derive(Debug)]
enum Acc {
    /// `COUNT(*)`: number of rows seen.
    CountStar(u64),
    /// `COUNT(expr)`: number of non-NULL argument values seen.
    Count(u64),
    /// `COUNT(DISTINCT expr)`: `key_repr` strings of the non-NULL values seen.
    CountDistinct(HashSet<String>),
    /// Sum preserving integer-ness: the result is an integer unless a float
    /// input was seen or the integer total overflowed.
    Sum {
        /// Checked integer total of the integer inputs; no longer updated
        /// by an addition that would overflow.
        int: i64,
        /// Floating-point total of every numeric input, kept in parallel.
        float: f64,
        /// Set once a `Float` input has been seen.
        saw_float: bool,
        /// Set once any numeric input has been seen (otherwise the sum is NULL).
        saw_any: bool,
        /// Set once an addition to `int` would have overflowed.
        overflowed: bool,
    },
    /// `AVG`: running total and count of the numeric inputs.
    Avg {
        sum: f64,
        n: u64,
    },
    /// `MIN`: smallest non-NULL value so far, `None` until one is seen.
    Min(Option<SochValue>),
    /// `MAX`: largest non-NULL value so far, `None` until one is seen.
    Max(Option<SochValue>),
    /// `MEDIAN`: every numeric input, buffered until finalization.
    Median(Vec<f64>),
    /// Welford online variance: count, running mean and sum of squared
    /// deviations from the mean.
    Stddev {
        n: u64,
        mean: f64,
        m2: f64,
    },
}
542
543impl Acc {
544    fn new(spec: &AggSpec) -> Self {
545        match (spec.func, spec.arg.is_some(), spec.distinct) {
546            (AggFn::Count, false, _) => Acc::CountStar(0),
547            (AggFn::Count, true, true) => Acc::CountDistinct(HashSet::new()),
548            (AggFn::Count, true, false) => Acc::Count(0),
549            (AggFn::Sum, _, _) => Acc::Sum {
550                int: 0,
551                float: 0.0,
552                saw_float: false,
553                saw_any: false,
554                overflowed: false,
555            },
556            (AggFn::Avg, _, _) => Acc::Avg { sum: 0.0, n: 0 },
557            (AggFn::Min, _, _) => Acc::Min(None),
558            (AggFn::Max, _, _) => Acc::Max(None),
559            (AggFn::Median, _, _) => Acc::Median(Vec::new()),
560            (AggFn::Stddev, _, _) => Acc::Stddev {
561                n: 0,
562                mean: 0.0,
563                m2: 0.0,
564            },
565        }
566    }
567
568    /// Update with the evaluated argument value (`None` only for COUNT(*)).
569    fn update(&mut self, val: Option<&SochValue>) {
570        match self {
571            Acc::CountStar(n) => *n += 1,
572            Acc::Count(n) => {
573                if let Some(v) = val {
574                    if !v.is_null() {
575                        *n += 1;
576                    }
577                }
578            }
579            Acc::CountDistinct(set) => {
580                if let Some(v) = val {
581                    if !v.is_null() {
582                        set.insert(key_repr(v));
583                    }
584                }
585            }
586            Acc::Sum {
587                int,
588                float,
589                saw_float,
590                saw_any,
591                overflowed,
592            } => {
593                let Some(v) = val else { return };
594                match v {
595                    SochValue::Int(i) => {
596                        *saw_any = true;
597                        match int.checked_add(*i) {
598                            Some(s) => *int = s,
599                            None => *overflowed = true,
600                        }
601                        *float += *i as f64;
602                    }
603                    SochValue::UInt(u) => {
604                        *saw_any = true;
605                        match int.checked_add(*u as i64) {
606                            Some(s) => *int = s,
607                            None => *overflowed = true,
608                        }
609                        *float += *u as f64;
610                    }
611                    SochValue::Float(f) => {
612                        *saw_any = true;
613                        *saw_float = true;
614                        *float += *f;
615                    }
616                    _ => {}
617                }
618            }
619            Acc::Avg { sum, n } => {
620                if let Some(x) = val.and_then(numeric) {
621                    *sum += x;
622                    *n += 1;
623                }
624            }
625            Acc::Min(cur) => {
626                let Some(v) = val else { return };
627                if v.is_null() {
628                    return;
629                }
630                match cur {
631                    None => *cur = Some(v.clone()),
632                    Some(c) => {
633                        if compare_values(v, c) == std::cmp::Ordering::Less {
634                            *cur = Some(v.clone());
635                        }
636                    }
637                }
638            }
639            Acc::Max(cur) => {
640                let Some(v) = val else { return };
641                if v.is_null() {
642                    return;
643                }
644                match cur {
645                    None => *cur = Some(v.clone()),
646                    Some(c) => {
647                        if compare_values(v, c) == std::cmp::Ordering::Greater {
648                            *cur = Some(v.clone());
649                        }
650                    }
651                }
652            }
653            Acc::Median(vals) => {
654                if let Some(x) = val.and_then(numeric) {
655                    vals.push(x);
656                }
657            }
658            Acc::Stddev { n, mean, m2 } => {
659                if let Some(x) = val.and_then(numeric) {
660                    *n += 1;
661                    let delta = x - *mean;
662                    *mean += delta / *n as f64;
663                    let delta2 = x - *mean;
664                    *m2 += delta * delta2;
665                }
666            }
667        }
668    }
669
670    /// Merge a partial accumulator (from a parallel chunk) into self.
671    /// Both must originate from the same `AggSpec`.
672    fn merge(&mut self, other: Acc) {
673        match (self, other) {
674            (Acc::CountStar(a), Acc::CountStar(b)) => *a += b,
675            (Acc::Count(a), Acc::Count(b)) => *a += b,
676            (Acc::CountDistinct(a), Acc::CountDistinct(b)) => a.extend(b),
677            (
678                Acc::Sum {
679                    int,
680                    float,
681                    saw_float,
682                    saw_any,
683                    overflowed,
684                },
685                Acc::Sum {
686                    int: i2,
687                    float: f2,
688                    saw_float: sf2,
689                    saw_any: sa2,
690                    overflowed: of2,
691                },
692            ) => {
693                match int.checked_add(i2) {
694                    Some(s) => *int = s,
695                    None => *overflowed = true,
696                }
697                *float += f2;
698                *saw_float |= sf2;
699                *saw_any |= sa2;
700                *overflowed |= of2;
701            }
702            (Acc::Avg { sum, n }, Acc::Avg { sum: s2, n: n2 }) => {
703                *sum += s2;
704                *n += n2;
705            }
706            (Acc::Min(a), Acc::Min(Some(b))) => match a {
707                None => *a = Some(b),
708                Some(cur) => {
709                    if compare_values(&b, cur) == std::cmp::Ordering::Less {
710                        *a = Some(b);
711                    }
712                }
713            },
714            (Acc::Max(a), Acc::Max(Some(b))) => match a {
715                None => *a = Some(b),
716                Some(cur) => {
717                    if compare_values(&b, cur) == std::cmp::Ordering::Greater {
718                        *a = Some(b);
719                    }
720                }
721            },
722            (Acc::Min(_), Acc::Min(None)) | (Acc::Max(_), Acc::Max(None)) => {}
723            (Acc::Median(a), Acc::Median(b)) => a.extend(b),
724            (
725                Acc::Stddev { n, mean, m2 },
726                Acc::Stddev {
727                    n: nb,
728                    mean: mb,
729                    m2: m2b,
730                },
731            ) => {
732                // Chan et al. parallel variance merge.
733                if nb > 0 {
734                    if *n == 0 {
735                        *n = nb;
736                        *mean = mb;
737                        *m2 = m2b;
738                    } else {
739                        let na = *n as f64;
740                        let nbf = nb as f64;
741                        let delta = mb - *mean;
742                        let total = na + nbf;
743                        *mean += delta * nbf / total;
744                        *m2 += m2b + delta * delta * na * nbf / total;
745                        *n += nb;
746                    }
747                }
748            }
749            _ => unreachable!("mismatched accumulator merge"),
750        }
751    }
752
753    fn finalize(self) -> SochValue {
754        match self {
755            Acc::CountStar(n) | Acc::Count(n) => SochValue::Int(n as i64),
756            Acc::CountDistinct(set) => SochValue::Int(set.len() as i64),
757            Acc::Sum {
758                int,
759                float,
760                saw_float,
761                saw_any,
762                overflowed,
763            } => {
764                if !saw_any {
765                    SochValue::Null
766                } else if saw_float || overflowed {
767                    SochValue::Float(float)
768                } else {
769                    SochValue::Int(int)
770                }
771            }
772            Acc::Avg { sum, n } => {
773                if n == 0 {
774                    SochValue::Null
775                } else {
776                    SochValue::Float(sum / n as f64)
777                }
778            }
779            Acc::Min(v) | Acc::Max(v) => v.unwrap_or(SochValue::Null),
780            Acc::Median(mut vals) => {
781                if vals.is_empty() {
782                    return SochValue::Null;
783                }
784                let mid = vals.len() / 2;
785                if vals.len() % 2 == 1 {
786                    let (_, m, _) =
787                        vals.select_nth_unstable_by(mid, |a, b| {
788                            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
789                        });
790                    SochValue::Float(*m)
791                } else {
792                    // Even count: average the two middle values.
793                    let (lo, hi_first, _) =
794                        vals.select_nth_unstable_by(mid, |a, b| {
795                            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
796                        });
797                    let lo_max = lo
798                        .iter()
799                        .copied()
800                        .fold(f64::NEG_INFINITY, f64::max);
801                    SochValue::Float((lo_max + *hi_first) / 2.0)
802                }
803            }
804            Acc::Stddev { n, m2, .. } => {
805                if n < 2 {
806                    SochValue::Null
807                } else {
808                    SochValue::Float((m2 / (n - 1) as f64).sqrt())
809                }
810            }
811        }
812    }
813}
814
815// ============================================================================
816// The aggregation operator
817// ============================================================================
818
/// Per-group state held while aggregating.
struct GroupState {
    /// Presumably the evaluated group-key values for this group, in
    /// GROUP BY order — confirm against the code that fills it.
    key_values: Vec<SochValue>,
    /// Presumably the first input row seen for the group, used to resolve
    /// non-aggregate columns that are not group keys — confirm.
    first_row: HashMap<String, SochValue>,
    /// Accumulators for this group; presumably one per collected `AggSpec`,
    /// in the same order — confirm.
    accs: Vec<Acc>,
}
824
825// ----------------------------------------------------------------------------
826// Fast path: plain-column group keys and aggregate args
827// ----------------------------------------------------------------------------
828
829/// A group-key atom that borrows string data from the input rows —
830/// zero allocations during accumulation lookups.
831#[derive(Debug, Clone, PartialEq, Eq, Hash)]
832enum KeyAtom<'a> {
833    Null,
834    Int(i64),
835    /// Normalized f64 bits (integral floats normalize to `Int`).
836    FBits(u64),
837    Str(&'a str),
838    Bool(bool),
839}
840
841impl<'a> KeyAtom<'a> {
842    fn from_value(v: &'a SochValue) -> Self {
843        match v {
844            SochValue::Null => KeyAtom::Null,
845            SochValue::Int(i) => KeyAtom::Int(*i),
846            SochValue::UInt(u) => KeyAtom::Int(*u as i64),
847            SochValue::Float(f) => {
848                if f.fract() == 0.0 && f.abs() < 9.0e15 {
849                    KeyAtom::Int(*f as i64)
850                } else if f.is_nan() {
851                    KeyAtom::FBits(f64::NAN.to_bits())
852                } else {
853                    KeyAtom::FBits(f.to_bits())
854                }
855            }
856            SochValue::Text(s) => KeyAtom::Str(s.as_str()),
857            SochValue::Bool(b) => KeyAtom::Bool(*b),
858            _ => KeyAtom::Null,
859        }
860    }
861}
862
863#[derive(Debug, Clone, PartialEq, Eq, Hash)]
864enum GroupKey<'a> {
865    Empty,
866    One(KeyAtom<'a>),
867    Many(Vec<KeyAtom<'a>>),
868}
869
870static NULL_VALUE: SochValue = SochValue::Null;
871
872/// Resolve a column reference against a row, trying qualified name first.
873#[inline]
874fn col_get<'r>(row: &'r HashMap<String, SochValue>, col: &PlainCol) -> &'r SochValue {
875    if let Some(q) = &col.qualified {
876        if let Some(v) = row.get(q) {
877            return v;
878        }
879    }
880    row.get(&col.name).unwrap_or(&NULL_VALUE)
881}
882
/// A column reference resolved once up front, so per-row lookups do no
/// string formatting.
struct PlainCol {
    /// Bare column name.
    name: String,
    /// `"table.col"` form, present only when the reference named a table.
    qualified: Option<String>,
}
888
889fn as_plain_col(expr: &Expr) -> Option<PlainCol> {
890    match expr {
891        Expr::Column(c) => Some(PlainCol {
892            name: c.column.clone(),
893            qualified: c.table.as_ref().map(|t| format!("{}.{}", t, c.column)),
894        }),
895        _ => None,
896    }
897}
898
899/// Build the borrowed group key for one row.
900fn make_group_key<'r>(
901    row: &'r HashMap<String, SochValue>,
902    group_cols: &[PlainCol],
903) -> GroupKey<'r> {
904    match group_cols.len() {
905        0 => GroupKey::Empty,
906        1 => GroupKey::One(KeyAtom::from_value(col_get(row, &group_cols[0]))),
907        _ => GroupKey::Many(
908            group_cols
909                .iter()
910                .map(|c| KeyAtom::from_value(col_get(row, c)))
911                .collect(),
912        ),
913    }
914}
915
916/// Try the optimized accumulation path. Applicable when every GROUP BY
917/// expression and every aggregate argument is a plain column reference
918/// (which covers typical analytics queries). Returns group states in
919/// first-seen order (per-chunk order under parallel execution).
920fn accumulate_fast<'a>(
921    select: &SelectStmt,
922    specs: &[AggSpec],
923    rows: &'a [HashMap<String, SochValue>],
924) -> Option<Vec<GroupState>> {
925    // Pre-resolve group-key columns.
926    let group_cols: Vec<PlainCol> = select
927        .group_by
928        .iter()
929        .map(as_plain_col)
930        .collect::<Option<Vec<_>>>()?;
931    // Pre-resolve aggregate argument columns (None = COUNT(*)).
932    let arg_cols: Vec<Option<PlainCol>> = specs
933        .iter()
934        .map(|s| match &s.arg {
935            None => Some(None),
936            Some(e) => as_plain_col(e).map(Some),
937        })
938        .collect::<Option<Vec<_>>>()?;
939
940    let accumulate_chunk = |chunk: &'a [HashMap<String, SochValue>]| -> Vec<(GroupKey<'a>, GroupState)> {
941        let mut order: Vec<(GroupKey<'a>, GroupState)> = Vec::new();
942        let mut index: HashMap<GroupKey<'a>, usize> = HashMap::new();
943        for row in chunk {
944            let key = make_group_key(row, &group_cols);
945            let idx = match index.get(&key) {
946                Some(&i) => i,
947                None => {
948                    let state = GroupState {
949                        key_values: group_cols
950                            .iter()
951                            .map(|c| col_get(row, c).clone())
952                            .collect(),
953                        first_row: row.clone(),
954                        accs: specs.iter().map(Acc::new).collect(),
955                    };
956                    order.push((key.clone(), state));
957                    index.insert(key, order.len() - 1);
958                    order.len() - 1
959                }
960            };
961            let accs = &mut order[idx].1.accs;
962            for (acc, arg) in accs.iter_mut().zip(arg_cols.iter()) {
963                match arg {
964                    None => acc.update(None),
965                    Some(col) => acc.update(Some(col_get(row, col))),
966                }
967            }
968        }
969        order
970    };
971
972    let merged: Vec<(GroupKey<'a>, GroupState)> = if rows.len() >= PARALLEL_THRESHOLD {
973        let n_threads = rayon::current_num_threads().max(1);
974        let chunk_size = (rows.len() / (n_threads * 4)).max(16_384);
975        let partials: Vec<Vec<(GroupKey<'a>, GroupState)>> = rows
976            .par_chunks(chunk_size)
977            .map(accumulate_chunk)
978            .collect();
979        // Merge chunk partials in chunk order.
980        let mut order: Vec<(GroupKey<'a>, GroupState)> = Vec::new();
981        let mut index: HashMap<GroupKey<'a>, usize> = HashMap::new();
982        for partial in partials {
983            for (key, state) in partial {
984                match index.get(&key) {
985                    Some(&i) => {
986                        let dst = &mut order[i].1;
987                        for (a, b) in dst.accs.iter_mut().zip(state.accs.into_iter()) {
988                            a.merge(b);
989                        }
990                    }
991                    None => {
992                        order.push((key.clone(), state));
993                        index.insert(key, order.len() - 1);
994                    }
995                }
996            }
997        }
998        order
999    } else {
1000        accumulate_chunk(rows)
1001    };
1002
1003    Some(merged.into_iter().map(|(_, s)| s).collect())
1004}
1005
1006/// Execute aggregation over materialized input rows (already WHERE-filtered).
1007///
1008/// Handles GROUP BY, all aggregate accumulation, HAVING, ORDER BY,
1009/// OFFSET/LIMIT, and final projection. Returns `ExecutionResult::Rows`.
1010pub fn execute_aggregate(
1011    select: &SelectStmt,
1012    rows: &[HashMap<String, SochValue>],
1013    params: &[SochValue],
1014    limit: Option<usize>,
1015    offset: Option<usize>,
1016) -> SqlResult<ExecutionResult> {
1017    let specs = collect_agg_specs(select);
1018    let grouped = !select.group_by.is_empty();
1019
1020    // ---- accumulate ----
1021    // Fast path: plain-column keys/args, borrowed-key hashing, parallel
1022    // partitioned accumulation. Falls back to the general expression-based
1023    // path for computed keys or computed aggregate arguments.
1024    let mut order: Vec<GroupState> = match accumulate_fast(select, &specs, rows) {
1025        Some(states) => states,
1026        None => {
1027            let mut order: Vec<GroupState> = Vec::new();
1028            let mut index: HashMap<Vec<String>, usize> = HashMap::new();
1029
1030            for row in rows {
1031                let key_values: Vec<SochValue> = select
1032                    .group_by
1033                    .iter()
1034                    .map(|e| eval_scalar(e, row, params))
1035                    .collect();
1036                let hash_key: Vec<String> = key_values.iter().map(key_repr).collect();
1037
1038                let idx = match index.get(&hash_key) {
1039                    Some(&i) => i,
1040                    None => {
1041                        let state = GroupState {
1042                            key_values,
1043                            first_row: row.clone(),
1044                            accs: specs.iter().map(Acc::new).collect(),
1045                        };
1046                        order.push(state);
1047                        index.insert(hash_key, order.len() - 1);
1048                        order.len() - 1
1049                    }
1050                };
1051
1052                let state = &mut order[idx];
1053                for (acc, spec) in state.accs.iter_mut().zip(specs.iter()) {
1054                    match &spec.arg {
1055                        None => acc.update(None),
1056                        Some(arg) => {
1057                            let v = eval_scalar(arg, row, params);
1058                            acc.update(Some(&v));
1059                        }
1060                    }
1061                }
1062            }
1063            order
1064        }
1065    };
1066
1067    // Ungrouped aggregate over zero rows still yields one (empty) group.
1068    if !grouped && order.is_empty() {
1069        order.push(GroupState {
1070            key_values: Vec::new(),
1071            first_row: HashMap::new(),
1072            accs: specs.iter().map(Acc::new).collect(),
1073        });
1074    }
1075
1076    // ---- finalize: synthesize one row per group ----
1077    let group_names: Vec<String> = select.group_by.iter().map(render_expr_name).collect();
1078
1079    let mut out_rows: Vec<HashMap<String, SochValue>> = Vec::with_capacity(order.len());
1080    for state in order {
1081        // Start from the first row of the group so non-aggregate columns
1082        // (lenient mode) and qualified names still resolve.
1083        let mut row = state.first_row;
1084        for (name, val) in group_names.iter().zip(state.key_values.into_iter()) {
1085            row.insert(name.clone(), val);
1086        }
1087        for (spec, acc) in specs.iter().zip(state.accs.into_iter()) {
1088            row.insert(spec.key.clone(), acc.finalize());
1089        }
1090        out_rows.push(row);
1091    }
1092
1093    // ---- HAVING ----
1094    if let Some(having) = &select.having {
1095        out_rows.retain(|row| {
1096            matches!(eval_scalar(having, row, params), SochValue::Bool(true))
1097        });
1098    }
1099
1100    // ---- ORDER BY (may reference aggregates or aliases) ----
1101    if !select.order_by.is_empty() {
1102        // Bind aliases so ORDER BY alias works.
1103        let alias_map: Vec<(String, Expr)> = select
1104            .columns
1105            .iter()
1106            .filter_map(|item| match item {
1107                SelectItem::Expr {
1108                    expr,
1109                    alias: Some(a),
1110                } => Some((a.clone(), expr.clone())),
1111                _ => None,
1112            })
1113            .collect();
1114        for row in &mut out_rows {
1115            for (alias, expr) in &alias_map {
1116                if !row.contains_key(alias) {
1117                    let v = eval_scalar(expr, row, params);
1118                    row.insert(alias.clone(), v);
1119                }
1120            }
1121        }
1122        out_rows.sort_by(|a, b| {
1123            for item in &select.order_by {
1124                let va = eval_scalar(&item.expr, a, params);
1125                let vb = eval_scalar(&item.expr, b, params);
1126                let mut cmp = compare_values(&va, &vb);
1127                if !item.asc {
1128                    cmp = cmp.reverse();
1129                }
1130                if cmp != std::cmp::Ordering::Equal {
1131                    return cmp;
1132                }
1133            }
1134            std::cmp::Ordering::Equal
1135        });
1136    }
1137
1138    // ---- OFFSET / LIMIT ----
1139    if let Some(off) = offset {
1140        if off > 0 {
1141            out_rows.drain(..off.min(out_rows.len()));
1142        }
1143    }
1144    if let Some(lim) = limit {
1145        out_rows.truncate(lim);
1146    }
1147
1148    // ---- projection ----
1149    let mut columns: Vec<String> = Vec::new();
1150    let mut projections: Vec<(String, Expr)> = Vec::new();
1151    for item in &select.columns {
1152        match item {
1153            SelectItem::Wildcard | SelectItem::QualifiedWildcard(_) => {
1154                // SELECT * with GROUP BY: project group keys then aggregates.
1155                for name in &group_names {
1156                    columns.push(name.clone());
1157                    projections.push((
1158                        name.clone(),
1159                        Expr::Column(ColumnRef::new(name.clone())),
1160                    ));
1161                }
1162                for spec in &specs {
1163                    columns.push(spec.key.clone());
1164                    projections.push((
1165                        spec.key.clone(),
1166                        Expr::Column(ColumnRef::new(spec.key.clone())),
1167                    ));
1168                }
1169            }
1170            SelectItem::Expr { expr, alias } => {
1171                let name = alias.clone().unwrap_or_else(|| render_expr_name(expr));
1172                columns.push(name.clone());
1173                projections.push((name, expr.clone()));
1174            }
1175        }
1176    }
1177
1178    let projected: Vec<HashMap<String, SochValue>> = out_rows
1179        .into_iter()
1180        .map(|row| {
1181            let mut out = HashMap::with_capacity(projections.len());
1182            for (name, expr) in &projections {
1183                let v = eval_scalar(expr, &row, params);
1184                out.insert(name.clone(), v);
1185            }
1186            out
1187        })
1188        .collect();
1189
1190    Ok(ExecutionResult::Rows {
1191        columns,
1192        rows: projected,
1193    })
1194}
1195
#[cfg(test)]
mod tests {
    use super::super::bridge::{SqlBridge, SqlConnection};
    use super::*;

    /// Row shape used by the in-memory fixture and by query results.
    type Row = HashMap<String, SochValue>;

    /// Build a one-argument, non-DISTINCT function call over a column.
    fn fcall(name: &str, arg: &str) -> Expr {
        let column = Expr::Column(ColumnRef::new(arg));
        let call = FunctionCall {
            name: ObjectName::new(name),
            args: vec![column],
            distinct: false,
            filter: None,
            over: None,
        };
        Expr::Function(call)
    }

    #[test]
    fn agg_fn_recognition() {
        let cases = [
            ("median", Some(AggFn::Median)),
            ("STDDEV", Some(AggFn::Stddev)),
            ("stddev_samp", Some(AggFn::Stddev)),
            ("upper", None),
        ];
        for (name, expected) in cases {
            assert_eq!(AggFn::from_name(name), expected);
        }
    }

    #[test]
    fn canonical_keys() {
        let sum_call = fcall("SUM", "v1");
        let median_call = fcall("Median", "v3");
        assert_eq!(render_expr_name(&sum_call), "sum(v1)");
        assert_eq!(render_expr_name(&median_call), "median(v3)");
    }

    // ========================================================================
    // End-to-end SQL tests through SqlBridge with an in-memory connection
    // ========================================================================

    /// In-memory table store implementing SqlConnection for tests.
    struct DataConn {
        tables: HashMap<String, Vec<Row>>,
    }

    impl DataConn {
        /// Start with no tables.
        fn new() -> Self {
            Self {
                tables: HashMap::new(),
            }
        }

        /// Register table `name`, pairing each row's values with `cols` by
        /// position.
        fn with_table(
            mut self,
            name: &str,
            cols: &[&str],
            rows: Vec<Vec<SochValue>>,
        ) -> Self {
            let mut table: Vec<Row> = Vec::with_capacity(rows.len());
            for vals in rows {
                let mut row = Row::new();
                for (col, val) in cols.iter().zip(vals) {
                    row.insert(col.to_string(), val);
                }
                table.push(row);
            }
            self.tables.insert(name.to_string(), table);
            self
        }

        /// Copy of every row in `table`; an unknown table yields no rows.
        fn snapshot(&self, table: &str) -> Vec<Row> {
            match self.tables.get(table) {
                Some(rows) => rows.clone(),
                None => Vec::new(),
            }
        }
    }

    impl SqlConnection for DataConn {
        fn select(
            &self,
            table: &str,
            _: &[String],
            _where_clause: Option<&Expr>,
            _: &[OrderByItem],
            _: Option<usize>,
            _: Option<usize>,
            _: &[SochValue],
        ) -> SqlResult<ExecutionResult> {
            // Tests using the aggregate path don't push WHERE here.
            Ok(ExecutionResult::Rows {
                columns: vec![],
                rows: self.snapshot(table),
            })
        }

        fn insert(
            &mut self,
            _: &str,
            _: Option<&[String]>,
            _: &[Vec<Expr>],
            _: Option<&OnConflict>,
            _: &[SochValue],
        ) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::RowsAffected(0))
        }

        fn update(
            &mut self,
            _: &str,
            _: &[Assignment],
            _: Option<&Expr>,
            _: &[SochValue],
        ) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::RowsAffected(0))
        }

        fn delete(
            &mut self,
            _: &str,
            _: Option<&Expr>,
            _: &[SochValue],
        ) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::RowsAffected(0))
        }

        fn create_table(&mut self, _: &CreateTableStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Ok)
        }

        fn drop_table(&mut self, _: &DropTableStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Ok)
        }

        fn create_index(&mut self, _: &CreateIndexStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Ok)
        }

        fn drop_index(&mut self, _: &DropIndexStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Ok)
        }

        fn alter_table(&mut self, _: &AlterTableStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Ok)
        }

        fn begin(&mut self, _: &BeginStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::TransactionOk)
        }

        fn commit(&mut self) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::TransactionOk)
        }

        fn rollback(&mut self, _: Option<&str>) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::TransactionOk)
        }

        fn table_exists(&self, t: &str) -> SqlResult<bool> {
            Ok(self.tables.contains_key(t))
        }

        fn index_exists(&self, _: &str) -> SqlResult<bool> {
            Ok(false)
        }

        fn scan_all(&self, table: &str, _: &[String]) -> SqlResult<Vec<Row>> {
            Ok(self.snapshot(table))
        }

        fn eval_join_predicate(
            &self,
            expr: &Expr,
            row: &HashMap<String, SochValue>,
            params: &[SochValue],
        ) -> Option<bool> {
            match eval_scalar(expr, row, params) {
                SochValue::Bool(b) => Some(b),
                SochValue::Null => Some(false),
                _ => None,
            }
        }
    }

    // Terse value constructors to keep fixtures readable.
    fn i(v: i64) -> SochValue {
        SochValue::Int(v)
    }

    fn f(v: f64) -> SochValue {
        SochValue::Float(v)
    }

    fn t(v: &str) -> SochValue {
        SochValue::Text(v.to_string())
    }

    /// db-benchmark-shaped fixture: x(id1 text, id3 text, v1 int, v2 int, v3 float)
    fn bench_bridge() -> SqlBridge<DataConn> {
        let fixture_rows = vec![
            vec![t("id001"), t("id0000001"), i(1), i(10), f(1.0)],
            vec![t("id001"), t("id0000002"), i(2), i(20), f(2.0)],
            vec![t("id002"), t("id0000001"), i(3), i(30), f(3.0)],
            vec![t("id002"), t("id0000002"), i(4), i(40), f(4.0)],
        ];
        let conn =
            DataConn::new().with_table("x", &["id1", "id3", "v1", "v2", "v3"], fixture_rows);
        SqlBridge::new(conn)
    }

    /// Unwrap a row-producing result; anything else fails the test.
    fn rows_of(result: ExecutionResult) -> Vec<Row> {
        match result {
            ExecutionResult::Rows { rows, .. } => rows,
            other => panic!("expected rows, got {:?}", other),
        }
    }

    /// Execute `sql`, requiring success and a row-producing result.
    fn run(bridge: &mut SqlBridge<DataConn>, sql: &str) -> Vec<Row> {
        rows_of(bridge.execute(sql).unwrap())
    }

    /// Fetch column `k` from `row`, failing loudly when it is absent.
    fn get<'a>(row: &'a HashMap<String, SochValue>, k: &str) -> &'a SochValue {
        match row.get(k) {
            Some(value) => value,
            None => panic!("column '{}' missing from {:?}", k, row),
        }
    }

    #[test]
    fn groupby_sum_q1_shape() {
        // db-benchmark q1: SELECT id1, sum(v1) AS v1 FROM x GROUP BY id1
        let mut bridge = bench_bridge();
        let rows = run(
            &mut bridge,
            "SELECT id1, sum(v1) AS v1 FROM x GROUP BY id1 ORDER BY id1",
        );
        assert_eq!(rows.len(), 2);
        assert_eq!(get(&rows[0], "id1"), &t("id001"));
        assert_eq!(get(&rows[0], "v1"), &i(3));
        assert_eq!(get(&rows[1], "id1"), &t("id002"));
        assert_eq!(get(&rows[1], "v1"), &i(7));
    }

    #[test]
    fn groupby_multi_key_mean() {
        // q4-like: SELECT id1, id3, avg(v1) AS m FROM x GROUP BY id1, id3
        let mut bridge = bench_bridge();
        let rows = run(
            &mut bridge,
            "SELECT id1, id3, avg(v1) AS m FROM x GROUP BY id1, id3 ORDER BY id1, id3",
        );
        assert_eq!(rows.len(), 4);
        assert_eq!(get(&rows[0], "m"), &f(1.0));
        assert_eq!(get(&rows[3], "m"), &f(4.0));
    }

    #[test]
    fn median_and_stddev() {
        // q7-like: SELECT median(v3), stddev(v3) FROM x
        // v3 = [1,2,3,4]: median = 2.5, sample sd = sqrt(5/3)
        let mut bridge = bench_bridge();
        let rows = run(
            &mut bridge,
            "SELECT median(v3) AS med, stddev(v3) AS sd FROM x",
        );
        assert_eq!(rows.len(), 1);
        let only = &rows[0];
        assert_eq!(get(only, "med"), &f(2.5));
        match get(only, "sd") {
            SochValue::Float(sd) => {
                assert!((sd - (5.0f64 / 3.0).sqrt()).abs() < 1e-12, "sd={}", sd)
            }
            other => panic!("expected float sd, got {:?}", other),
        }
    }

    #[test]
    fn median_odd_count() {
        let values = vec![vec![f(5.0)], vec![f(1.0)], vec![f(3.0)]];
        let mut bridge = SqlBridge::new(DataConn::new().with_table("t", &["v"], values));
        let rows = run(&mut bridge, "SELECT median(v) AS m FROM t");
        assert_eq!(get(&rows[0], "m"), &f(3.0));
    }

    #[test]
    fn range_expression_q9_shape() {
        // q9: SELECT id3, max(v1) - min(v2) AS range_v1_v2 FROM x GROUP BY id3
        let mut bridge = bench_bridge();
        let rows = run(
            &mut bridge,
            "SELECT id3, max(v1) - min(v2) AS range_v1_v2 FROM x GROUP BY id3 ORDER BY id3",
        );
        assert_eq!(rows.len(), 2);
        // id0000001: max(v1)=3, min(v2)=10 -> -7
        assert_eq!(get(&rows[0], "range_v1_v2"), &i(-7));
        // id0000002: max(v1)=4, min(v2)=20 -> -16
        assert_eq!(get(&rows[1], "range_v1_v2"), &i(-16));
    }

    #[test]
    fn count_star_vs_count_col_with_nulls() {
        let data = vec![
            vec![t("a"), i(1)],
            vec![t("a"), SochValue::Null],
            vec![t("b"), i(2)],
        ];
        let mut bridge = SqlBridge::new(DataConn::new().with_table("t", &["g", "v"], data));
        let rows = run(
            &mut bridge,
            "SELECT g, count(*) AS n, count(v) AS nv FROM t GROUP BY g ORDER BY g",
        );
        assert_eq!(rows.len(), 2);
        assert_eq!(get(&rows[0], "n"), &i(2));
        assert_eq!(get(&rows[0], "nv"), &i(1));
        assert_eq!(get(&rows[1], "n"), &i(1));
        assert_eq!(get(&rows[1], "nv"), &i(1));
    }

    #[test]
    fn count_distinct() {
        // q6-like: SELECT id3, count(DISTINCT id1) AS u FROM x GROUP BY id3
        let mut bridge = bench_bridge();
        let rows = run(
            &mut bridge,
            "SELECT id3, count(DISTINCT id1) AS u FROM x GROUP BY id3 ORDER BY id3",
        );
        assert_eq!(rows.len(), 2);
        assert_eq!(get(&rows[0], "u"), &i(2));
        assert_eq!(get(&rows[1], "u"), &i(2));
    }

    #[test]
    fn having_filters_groups() {
        let mut bridge = bench_bridge();
        let rows = run(
            &mut bridge,
            "SELECT id1, sum(v1) AS s FROM x GROUP BY id1 HAVING sum(v1) > 5",
        );
        assert_eq!(rows.len(), 1);
        assert_eq!(get(&rows[0], "id1"), &t("id002"));
        assert_eq!(get(&rows[0], "s"), &i(7));
    }

    #[test]
    fn order_by_aggregate_desc_with_limit() {
        let mut bridge = bench_bridge();
        let rows = run(
            &mut bridge,
            "SELECT id1, sum(v1) AS s FROM x GROUP BY id1 ORDER BY s DESC LIMIT 1",
        );
        assert_eq!(rows.len(), 1);
        assert_eq!(get(&rows[0], "id1"), &t("id002"));
    }

    #[test]
    fn ungrouped_aggregate_over_empty_table() {
        let mut bridge = SqlBridge::new(DataConn::new().with_table("e", &["v"], vec![]));
        let rows = run(&mut bridge, "SELECT count(*) AS n, sum(v) AS s FROM e");
        assert_eq!(rows.len(), 1, "ungrouped agg over empty input = one row");
        assert_eq!(get(&rows[0], "n"), &i(0));
        assert_eq!(get(&rows[0], "s"), &SochValue::Null);
    }

    #[test]
    fn grouped_aggregate_over_empty_table_yields_no_rows() {
        let mut bridge = SqlBridge::new(DataConn::new().with_table("e", &["g", "v"], vec![]));
        let rows = run(&mut bridge, "SELECT g, sum(v) AS s FROM e GROUP BY g");
        assert!(rows.is_empty());
    }

    #[test]
    fn sum_overflow_promotes_to_float() {
        let data = vec![vec![i(i64::MAX)], vec![i(i64::MAX)]];
        let mut bridge = SqlBridge::new(DataConn::new().with_table("t", &["v"], data));
        let rows = run(&mut bridge, "SELECT sum(v) AS s FROM t");
        match get(&rows[0], "s") {
            SochValue::Float(v) => assert!(*v > 1.8e19),
            other => panic!("expected float after overflow, got {:?}", other),
        }
    }

    #[test]
    fn aggregate_after_join() {
        // join + group: SELECT x.id1, sum(y.w) FROM x JOIN y ON x.id1 = y.id1 GROUP BY x.id1
        let left = vec![
            vec![t("k1"), i(1)],
            vec![t("k1"), i(2)],
            vec![t("k2"), i(3)],
        ];
        let right = vec![vec![t("k1"), i(10)], vec![t("k2"), i(20)]];
        let conn = DataConn::new()
            .with_table("a", &["id", "v"], left)
            .with_table("b", &["id", "w"], right);
        let mut bridge = SqlBridge::new(conn);
        let rows = run(
            &mut bridge,
            "SELECT a.id, sum(a.v) AS sv, sum(b.w) AS sw \
             FROM a JOIN b ON a.id = b.id GROUP BY a.id ORDER BY a.id",
        );
        assert_eq!(rows.len(), 2);
        assert_eq!(get(&rows[0], "sv"), &i(3));
        assert_eq!(get(&rows[0], "sw"), &i(20)); // 10 joined to both k1 rows
        assert_eq!(get(&rows[1], "sv"), &i(3));
        assert_eq!(get(&rows[1], "sw"), &i(20));
    }

    #[test]
    fn lowercase_function_names_parse() {
        // db-benchmark SQL uses lowercase: sum(v1), median(v3)
        let mut bridge = bench_bridge();
        let statements = [
            "SELECT id1, sum(v1) FROM x GROUP BY id1",
            "SELECT median(v3) FROM x",
            "SELECT stddev(v3) FROM x",
        ];
        for sql in statements {
            assert!(bridge.execute(sql).is_ok());
        }
    }

    #[test]
    fn parallel_path_matches_reference_computation() {
        // 150k rows (> PARALLEL_THRESHOLD) exercising the rayon merge:
        // sum, avg, count, median, stddev per group, verified against
        // values computed directly in the test.
        let n: usize = 150_000;
        let groups = 7usize;
        let value_at = |idx: usize| (idx * 31 % 1000) as f64 / 4.0;

        // Table rows and the reference buckets are filled in the same pass.
        let mut data: Vec<Vec<SochValue>> = Vec::with_capacity(n);
        let mut per_group: Vec<Vec<f64>> = vec![Vec::new(); groups];
        for idx in 0..n {
            let bucket = idx % groups;
            let value = value_at(idx);
            data.push(vec![t(&format!("g{}", bucket)), f(value)]);
            per_group[bucket].push(value);
        }

        let mut bridge = SqlBridge::new(DataConn::new().with_table("big", &["g", "v"], data));
        let rows = run(
            &mut bridge,
            "SELECT g, count(*) AS n, sum(v) AS s, avg(v) AS m, \
             median(v) AS med, stddev(v) AS sd FROM big GROUP BY g ORDER BY g",
        );
        assert_eq!(rows.len(), groups);

        for (gi, row) in rows.iter().enumerate() {
            let vals = &per_group[gi];
            let cnt = vals.len() as f64;
            let sum: f64 = vals.iter().sum();
            let mean = sum / cnt;
            let var =
                vals.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / (cnt - 1.0);

            let mut sorted = vals.clone();
            sorted.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
            let mid = sorted.len() / 2;
            let med = if sorted.len() % 2 == 1 {
                sorted[mid]
            } else {
                (sorted[mid - 1] + sorted[mid]) / 2.0
            };

            assert_eq!(get(row, "g"), &t(&format!("g{}", gi)));
            assert_eq!(get(row, "n"), &i(vals.len() as i64));
            match get(row, "s") {
                SochValue::Float(v) => assert!((v - sum).abs() < 1e-6, "sum"),
                other => panic!("sum type {:?}", other),
            }
            match get(row, "m") {
                SochValue::Float(v) => assert!((v - mean).abs() < 1e-9, "mean"),
                other => panic!("mean type {:?}", other),
            }
            match get(row, "med") {
                SochValue::Float(v) => assert!((v - med).abs() < 1e-9, "median"),
                other => panic!("median type {:?}", other),
            }
            match get(row, "sd") {
                SochValue::Float(v) => {
                    assert!((v - var.sqrt()).abs() < 1e-9, "sd {} vs {}", v, var.sqrt())
                }
                other => panic!("sd type {:?}", other),
            }
        }
    }

    #[test]
    fn unaliased_aggregate_column_name_is_canonical() {
        let mut bridge = bench_bridge();
        let result = bridge
            .execute("SELECT id1, sum(v1) FROM x GROUP BY id1")
            .unwrap();
        let cols = result.columns().unwrap().clone();
        assert!(cols.contains(&"sum(v1)".to_string()), "cols={:?}", cols);
    }
}