1use crate::soch_ql::SochValue;
8use crate::sql::ast::Expr;
9use super::eval::{eval_expr, compare_values};
10use super::node::PlanNode;
11use super::types::{Row, Schema, ColumnMeta};
12use sochdb_core::Result;
13use std::collections::HashMap;
14
/// Aggregate functions understood by the hash-aggregate operator.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AggFunc {
    /// Row count when the aggregate has no argument expression, otherwise
    /// the number of non-null argument values.
    Count,
    /// Number of distinct non-null argument values.
    CountDistinct,
    /// Sum of the numeric argument values; `Null` when no non-null value was seen.
    Sum,
    /// Floating-point mean of the argument values; `Null` when no non-null value was seen.
    Avg,
    /// Smallest non-null argument value; `Null` when there is none.
    Min,
    /// Largest non-null argument value; `Null` when there is none.
    Max,
}
25
26#[derive(Debug, Clone)]
28pub struct AggDef {
29 pub func: AggFunc,
31 pub expr: Option<Expr>,
33 pub alias: String,
35}
36
37struct Accumulator {
39 func: AggFunc,
40 count: u64,
41 sum_int: i64,
42 sum_float: f64,
43 is_float: bool,
44 min_val: Option<SochValue>,
45 max_val: Option<SochValue>,
46 distinct_set: Option<Vec<SochValue>>,
47}
48
49impl Accumulator {
50 fn new(func: &AggFunc) -> Self {
51 Self {
52 func: func.clone(),
53 count: 0,
54 sum_int: 0,
55 sum_float: 0.0,
56 is_float: false,
57 min_val: None,
58 max_val: None,
59 distinct_set: if matches!(func, AggFunc::CountDistinct) {
60 Some(Vec::new())
61 } else {
62 None
63 },
64 }
65 }
66
67 fn accumulate(&mut self, val: &SochValue) {
68 if matches!(val, SochValue::Null) {
70 if matches!(self.func, AggFunc::Count) {
73 return;
75 }
76 return;
77 }
78
79 match self.func {
80 AggFunc::Count => {
81 self.count += 1;
82 }
83 AggFunc::CountDistinct => {
84 if let Some(set) = &mut self.distinct_set {
85 let already = set.iter().any(|v| {
86 compare_values(v, val) == Some(std::cmp::Ordering::Equal)
87 });
88 if !already {
89 set.push(val.clone());
90 }
91 }
92 }
93 AggFunc::Sum => {
94 match val {
95 SochValue::Int(i) => {
96 if self.is_float {
97 self.sum_float += *i as f64;
98 } else {
99 self.sum_int += i;
100 }
101 }
102 SochValue::UInt(u) => {
103 if self.is_float {
104 self.sum_float += *u as f64;
105 } else {
106 self.sum_int += *u as i64;
107 }
108 }
109 SochValue::Float(f) => {
110 if !self.is_float {
111 self.sum_float = self.sum_int as f64;
112 self.is_float = true;
113 }
114 self.sum_float += f;
115 }
116 _ => {}
117 }
118 self.count += 1;
119 }
120 AggFunc::Avg => {
121 match val {
122 SochValue::Int(i) => self.sum_float += *i as f64,
123 SochValue::UInt(u) => self.sum_float += *u as f64,
124 SochValue::Float(f) => self.sum_float += f,
125 _ => {}
126 }
127 self.count += 1;
128 }
129 AggFunc::Min => {
130 let update = match &self.min_val {
131 None => true,
132 Some(current) => {
133 compare_values(val, current) == Some(std::cmp::Ordering::Less)
134 }
135 };
136 if update {
137 self.min_val = Some(val.clone());
138 }
139 }
140 AggFunc::Max => {
141 let update = match &self.max_val {
142 None => true,
143 Some(current) => {
144 compare_values(val, current) == Some(std::cmp::Ordering::Greater)
145 }
146 };
147 if update {
148 self.max_val = Some(val.clone());
149 }
150 }
151 }
152 }
153
154 fn finalize(&self) -> SochValue {
155 match self.func {
156 AggFunc::Count => SochValue::Int(self.count as i64),
157 AggFunc::CountDistinct => {
158 SochValue::Int(
159 self.distinct_set.as_ref().map_or(0, |s| s.len()) as i64,
160 )
161 }
162 AggFunc::Sum => {
163 if self.count == 0 {
164 SochValue::Null
165 } else if self.is_float {
166 SochValue::Float(self.sum_float)
167 } else {
168 SochValue::Int(self.sum_int)
169 }
170 }
171 AggFunc::Avg => {
172 if self.count == 0 {
173 SochValue::Null
174 } else {
175 SochValue::Float(self.sum_float / self.count as f64)
176 }
177 }
178 AggFunc::Min => self.min_val.clone().unwrap_or(SochValue::Null),
179 AggFunc::Max => self.max_val.clone().unwrap_or(SochValue::Null),
180 }
181 }
182}
183
184#[derive(Debug, Clone, PartialEq, Eq, Hash)]
186struct GroupKey(Vec<GroupVal>);
187
/// Hashable, equality-comparable stand-in for a single grouping value.
///
/// Values of different variants never compare equal, even when they hold
/// the same number or text.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum GroupVal {
    /// The null value.
    Null,
    /// A boolean.
    Bool(bool),
    /// A signed integer.
    Int(i64),
    /// An unsigned integer.
    UInt(u64),
    /// A text value.
    Text(String),
    /// Any other value, keyed by a string rendering of it.
    Other(String),
}
197
198impl From<&SochValue> for GroupVal {
199 fn from(v: &SochValue) -> Self {
200 match v {
201 SochValue::Null => GroupVal::Null,
202 SochValue::Bool(b) => GroupVal::Bool(*b),
203 SochValue::Int(i) => GroupVal::Int(*i),
204 SochValue::UInt(u) => GroupVal::UInt(*u),
205 SochValue::Text(s) => GroupVal::Text(s.clone()),
206 other => GroupVal::Other(format!("{:?}", other)),
207 }
208 }
209}
210
211struct GroupState {
213 key_values: Vec<SochValue>,
214 accumulators: Vec<Accumulator>,
215}
216
217pub struct HashAggregateNode {
227 input: Box<dyn PlanNode>,
228 group_by_exprs: Vec<Expr>,
229 agg_defs: Vec<AggDef>,
230 output_schema: Schema,
231 groups: Option<Vec<Row>>,
233 pos: usize,
234 is_global: bool,
236}
237
238impl HashAggregateNode {
239 pub fn new(
240 input: Box<dyn PlanNode>,
241 group_by_exprs: Vec<Expr>,
242 agg_defs: Vec<AggDef>,
243 ) -> Self {
244 let is_global = group_by_exprs.is_empty();
245
246 let mut cols: Vec<ColumnMeta> = group_by_exprs
248 .iter()
249 .map(|e| {
250 let name = match e {
251 Expr::Column(c) => c.column.clone(),
252 _ => format!("{:?}", e),
253 };
254 ColumnMeta::new(name)
255 })
256 .collect();
257 for ad in &agg_defs {
258 cols.push(ColumnMeta::new(ad.alias.clone()));
259 }
260 let output_schema = Schema::new(cols);
261
262 Self {
263 input,
264 group_by_exprs,
265 agg_defs,
266 output_schema,
267 groups: None,
268 pos: 0,
269 is_global,
270 }
271 }
272
273 fn materialize(&mut self) -> Result<()> {
274 if self.groups.is_some() {
275 return Ok(());
276 }
277
278 let input_schema = self.input.schema().clone();
279 let mut group_map: HashMap<GroupKey, GroupState> = HashMap::new();
280 let mut group_order: Vec<GroupKey> = Vec::new(); let has_count_star: Vec<bool> = self.agg_defs.iter().map(|ad| {
284 matches!(ad.func, AggFunc::Count) && ad.expr.is_none()
285 }).collect();
286
287 while let Some(row) = self.input.next()? {
288 let key_values: Vec<SochValue> = self
290 .group_by_exprs
291 .iter()
292 .map(|e| eval_expr(e, &row, &input_schema).unwrap_or(SochValue::Null))
293 .collect();
294
295 let group_key = GroupKey(key_values.iter().map(GroupVal::from).collect());
296
297 let state = group_map.entry(group_key.clone()).or_insert_with(|| {
298 group_order.push(group_key.clone());
299 GroupState {
300 key_values: key_values.clone(),
301 accumulators: self
302 .agg_defs
303 .iter()
304 .map(|ad| Accumulator::new(&ad.func))
305 .collect(),
306 }
307 });
308
309 for (i, ad) in self.agg_defs.iter().enumerate() {
311 if has_count_star[i] {
312 state.accumulators[i].count += 1;
314 } else if let Some(expr) = &ad.expr {
315 let val = eval_expr(expr, &row, &input_schema)?;
316 state.accumulators[i].accumulate(&val);
317 }
318 }
319 }
320
321 if self.is_global && group_map.is_empty() {
323 let mut row: Row = Vec::new();
324 for ad in &self.agg_defs {
325 let acc = Accumulator::new(&ad.func);
326 row.push(acc.finalize());
327 }
328 self.groups = Some(vec![row]);
329 return Ok(());
330 }
331
332 let mut result = Vec::with_capacity(group_order.len());
334 for gk in &group_order {
335 if let Some(state) = group_map.get(gk) {
336 let mut row: Row = state.key_values.clone();
337 for acc in &state.accumulators {
338 row.push(acc.finalize());
339 }
340 result.push(row);
341 }
342 }
343
344 self.groups = Some(result);
345 Ok(())
346 }
347}
348
349impl PlanNode for HashAggregateNode {
350 fn schema(&self) -> &Schema {
351 &self.output_schema
352 }
353
354 fn next(&mut self) -> Result<Option<Row>> {
355 self.materialize()?;
356
357 if let Some(groups) = &self.groups {
358 if self.pos < groups.len() {
359 let row = groups[self.pos].clone();
360 self.pos += 1;
361 Ok(Some(row))
362 } else {
363 Ok(None)
364 }
365 } else {
366 Ok(None)
367 }
368 }
369
370 fn reset(&mut self) -> Result<()> {
371 self.groups = None;
372 self.pos = 0;
373 self.input.reset()
374 }
375}