Skip to main content

sochdb_query/executor/
aggregate.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2
3//! Hash aggregate operator (GROUP BY + aggregate functions).
4//!
5//! Supports: COUNT, SUM, AVG, MIN, MAX, COUNT(DISTINCT ...)
6
7use crate::soch_ql::SochValue;
8use crate::sql::ast::Expr;
9use super::eval::{eval_expr, compare_values};
10use super::node::PlanNode;
11use super::types::{Row, Schema, ColumnMeta};
12use sochdb_core::Result;
13use std::collections::HashMap;
14
/// The aggregate functions the hash aggregate operator can evaluate.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AggFunc {
    /// `COUNT(*)` or `COUNT(expr)`.
    Count,
    /// `COUNT(DISTINCT expr)`.
    CountDistinct,
    /// `SUM(expr)`.
    Sum,
    /// `AVG(expr)`.
    Avg,
    /// `MIN(expr)`.
    Min,
    /// `MAX(expr)`.
    Max,
}
25
26/// Definition of an aggregate computation.
27#[derive(Debug, Clone)]
28pub struct AggDef {
29    /// Function type.
30    pub func: AggFunc,
31    /// Expression to aggregate (None for COUNT(*)).
32    pub expr: Option<Expr>,
33    /// Output column alias.
34    pub alias: String,
35}
36
37/// Internal accumulator for one aggregate function.
38struct Accumulator {
39    func: AggFunc,
40    count: u64,
41    sum_int: i64,
42    sum_float: f64,
43    is_float: bool,
44    min_val: Option<SochValue>,
45    max_val: Option<SochValue>,
46    distinct_set: Option<Vec<SochValue>>,
47}
48
49impl Accumulator {
50    fn new(func: &AggFunc) -> Self {
51        Self {
52            func: func.clone(),
53            count: 0,
54            sum_int: 0,
55            sum_float: 0.0,
56            is_float: false,
57            min_val: None,
58            max_val: None,
59            distinct_set: if matches!(func, AggFunc::CountDistinct) {
60                Some(Vec::new())
61            } else {
62                None
63            },
64        }
65    }
66
67    fn accumulate(&mut self, val: &SochValue) {
68        // Skip NULLs for all aggregates except COUNT(*)
69        if matches!(val, SochValue::Null) {
70            // COUNT(*) still counts NULLs — handled by the caller
71            // passing SochValue::Bool(true) for count(*)
72            if matches!(self.func, AggFunc::Count) {
73                // COUNT(*) — caller decides, regular COUNT(col) should skip
74                return;
75            }
76            return;
77        }
78
79        match self.func {
80            AggFunc::Count => {
81                self.count += 1;
82            }
83            AggFunc::CountDistinct => {
84                if let Some(set) = &mut self.distinct_set {
85                    let already = set.iter().any(|v| {
86                        compare_values(v, val) == Some(std::cmp::Ordering::Equal)
87                    });
88                    if !already {
89                        set.push(val.clone());
90                    }
91                }
92            }
93            AggFunc::Sum => {
94                match val {
95                    SochValue::Int(i) => {
96                        if self.is_float {
97                            self.sum_float += *i as f64;
98                        } else {
99                            self.sum_int += i;
100                        }
101                    }
102                    SochValue::UInt(u) => {
103                        if self.is_float {
104                            self.sum_float += *u as f64;
105                        } else {
106                            self.sum_int += *u as i64;
107                        }
108                    }
109                    SochValue::Float(f) => {
110                        if !self.is_float {
111                            self.sum_float = self.sum_int as f64;
112                            self.is_float = true;
113                        }
114                        self.sum_float += f;
115                    }
116                    _ => {}
117                }
118                self.count += 1;
119            }
120            AggFunc::Avg => {
121                match val {
122                    SochValue::Int(i) => self.sum_float += *i as f64,
123                    SochValue::UInt(u) => self.sum_float += *u as f64,
124                    SochValue::Float(f) => self.sum_float += f,
125                    _ => {}
126                }
127                self.count += 1;
128            }
129            AggFunc::Min => {
130                let update = match &self.min_val {
131                    None => true,
132                    Some(current) => {
133                        compare_values(val, current) == Some(std::cmp::Ordering::Less)
134                    }
135                };
136                if update {
137                    self.min_val = Some(val.clone());
138                }
139            }
140            AggFunc::Max => {
141                let update = match &self.max_val {
142                    None => true,
143                    Some(current) => {
144                        compare_values(val, current) == Some(std::cmp::Ordering::Greater)
145                    }
146                };
147                if update {
148                    self.max_val = Some(val.clone());
149                }
150            }
151        }
152    }
153
154    fn finalize(&self) -> SochValue {
155        match self.func {
156            AggFunc::Count => SochValue::Int(self.count as i64),
157            AggFunc::CountDistinct => {
158                SochValue::Int(
159                    self.distinct_set.as_ref().map_or(0, |s| s.len()) as i64,
160                )
161            }
162            AggFunc::Sum => {
163                if self.count == 0 {
164                    SochValue::Null
165                } else if self.is_float {
166                    SochValue::Float(self.sum_float)
167                } else {
168                    SochValue::Int(self.sum_int)
169                }
170            }
171            AggFunc::Avg => {
172                if self.count == 0 {
173                    SochValue::Null
174                } else {
175                    SochValue::Float(self.sum_float / self.count as f64)
176                }
177            }
178            AggFunc::Min => self.min_val.clone().unwrap_or(SochValue::Null),
179            AggFunc::Max => self.max_val.clone().unwrap_or(SochValue::Null),
180        }
181    }
182}
183
184/// Hash-key for GROUP BY: tuple of group-by values.
185#[derive(Debug, Clone, PartialEq, Eq, Hash)]
186struct GroupKey(Vec<GroupVal>);
187
/// Hashable stand-in for a single group-by value.
///
/// Variants of different kinds never compare equal, so `Int(1)` and `UInt(1)`
/// land in different groups.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum GroupVal {
    /// SQL NULL; all NULL keys share one group.
    Null,
    /// Boolean key.
    Bool(bool),
    /// Signed integer key.
    Int(i64),
    /// Unsigned integer key.
    UInt(u64),
    /// Text key.
    Text(String),
    /// Any other value, keyed by its `Debug` rendering.
    Other(String),
}
197
198impl From<&SochValue> for GroupVal {
199    fn from(v: &SochValue) -> Self {
200        match v {
201            SochValue::Null => GroupVal::Null,
202            SochValue::Bool(b) => GroupVal::Bool(*b),
203            SochValue::Int(i) => GroupVal::Int(*i),
204            SochValue::UInt(u) => GroupVal::UInt(*u),
205            SochValue::Text(s) => GroupVal::Text(s.clone()),
206            other => GroupVal::Other(format!("{:?}", other)),
207        }
208    }
209}
210
211/// Group state: group-by values + accumulators.
212struct GroupState {
213    key_values: Vec<SochValue>,
214    accumulators: Vec<Accumulator>,
215}
216
217/// Hash aggregate operator.
218///
219/// Materializes all input, groups by key expressions, computes aggregates,
220/// then emits one row per group.
221///
222/// ```text
223/// HashAggregate(group_by=[dept], aggs=[COUNT(*), AVG(salary)])
224///   └── input
225/// ```
226pub struct HashAggregateNode {
227    input: Box<dyn PlanNode>,
228    group_by_exprs: Vec<Expr>,
229    agg_defs: Vec<AggDef>,
230    output_schema: Schema,
231    /// Materialized groups (lazily computed).
232    groups: Option<Vec<Row>>,
233    pos: usize,
234    /// Whether this is a global aggregate (no GROUP BY).
235    is_global: bool,
236}
237
238impl HashAggregateNode {
239    pub fn new(
240        input: Box<dyn PlanNode>,
241        group_by_exprs: Vec<Expr>,
242        agg_defs: Vec<AggDef>,
243    ) -> Self {
244        let is_global = group_by_exprs.is_empty();
245
246        // Build output schema: group-by columns + aggregate columns
247        let mut cols: Vec<ColumnMeta> = group_by_exprs
248            .iter()
249            .map(|e| {
250                let name = match e {
251                    Expr::Column(c) => c.column.clone(),
252                    _ => format!("{:?}", e),
253                };
254                ColumnMeta::new(name)
255            })
256            .collect();
257        for ad in &agg_defs {
258            cols.push(ColumnMeta::new(ad.alias.clone()));
259        }
260        let output_schema = Schema::new(cols);
261
262        Self {
263            input,
264            group_by_exprs,
265            agg_defs,
266            output_schema,
267            groups: None,
268            pos: 0,
269            is_global,
270        }
271    }
272
273    fn materialize(&mut self) -> Result<()> {
274        if self.groups.is_some() {
275            return Ok(());
276        }
277
278        let input_schema = self.input.schema().clone();
279        let mut group_map: HashMap<GroupKey, GroupState> = HashMap::new();
280        let mut group_order: Vec<GroupKey> = Vec::new(); // Preserve insertion order
281
282        // Count(*) tracking
283        let has_count_star: Vec<bool> = self.agg_defs.iter().map(|ad| {
284            matches!(ad.func, AggFunc::Count) && ad.expr.is_none()
285        }).collect();
286
287        while let Some(row) = self.input.next()? {
288            // Evaluate group-by keys
289            let key_values: Vec<SochValue> = self
290                .group_by_exprs
291                .iter()
292                .map(|e| eval_expr(e, &row, &input_schema).unwrap_or(SochValue::Null))
293                .collect();
294
295            let group_key = GroupKey(key_values.iter().map(GroupVal::from).collect());
296
297            let state = group_map.entry(group_key.clone()).or_insert_with(|| {
298                group_order.push(group_key.clone());
299                GroupState {
300                    key_values: key_values.clone(),
301                    accumulators: self
302                        .agg_defs
303                        .iter()
304                        .map(|ad| Accumulator::new(&ad.func))
305                        .collect(),
306                }
307            });
308
309            // Accumulate values
310            for (i, ad) in self.agg_defs.iter().enumerate() {
311                if has_count_star[i] {
312                    // COUNT(*) — count every row
313                    state.accumulators[i].count += 1;
314                } else if let Some(expr) = &ad.expr {
315                    let val = eval_expr(expr, &row, &input_schema)?;
316                    state.accumulators[i].accumulate(&val);
317                }
318            }
319        }
320
321        // Handle global aggregate with no input rows
322        if self.is_global && group_map.is_empty() {
323            let mut row: Row = Vec::new();
324            for ad in &self.agg_defs {
325                let acc = Accumulator::new(&ad.func);
326                row.push(acc.finalize());
327            }
328            self.groups = Some(vec![row]);
329            return Ok(());
330        }
331
332        // Build output rows in insertion order
333        let mut result = Vec::with_capacity(group_order.len());
334        for gk in &group_order {
335            if let Some(state) = group_map.get(gk) {
336                let mut row: Row = state.key_values.clone();
337                for acc in &state.accumulators {
338                    row.push(acc.finalize());
339                }
340                result.push(row);
341            }
342        }
343
344        self.groups = Some(result);
345        Ok(())
346    }
347}
348
349impl PlanNode for HashAggregateNode {
350    fn schema(&self) -> &Schema {
351        &self.output_schema
352    }
353
354    fn next(&mut self) -> Result<Option<Row>> {
355        self.materialize()?;
356
357        if let Some(groups) = &self.groups {
358            if self.pos < groups.len() {
359                let row = groups[self.pos].clone();
360                self.pos += 1;
361                Ok(Some(row))
362            } else {
363                Ok(None)
364            }
365        } else {
366            Ok(None)
367        }
368    }
369
370    fn reset(&mut self) -> Result<()> {
371        self.groups = None;
372        self.pos = 0;
373        self.input.reset()
374    }
375}