1use anyhow::{anyhow, Result};
4use fxhash::FxHashMap;
5use std::sync::Arc;
6use std::time::Duration;
7
8use crate::data::arithmetic_evaluator::ArithmeticEvaluator;
9use crate::data::data_view::DataView;
10use crate::data::datatable::{DataColumn, DataRow, DataTable, DataValue};
11use crate::data::query_engine::QueryEngine;
12use crate::sql::aggregates::contains_aggregate;
13use crate::sql::parser::ast::{SelectItem, SqlExpression};
14use tracing::debug;
15
/// Timing and volume statistics reported for one GROUP BY execution.
///
/// The `phaseN_*` fields are named after the individual stages of grouping.
/// Note that `apply_group_by_expressions` in this file only measures group
/// building as a whole, so it fills some of them with `Duration::ZERO`
/// (see the per-field notes).
#[derive(Clone, Debug)]
pub struct GroupByPhaseInfo {
    /// Row count of the view that was grouped.
    pub total_rows: usize,
    /// Number of groups emitted into the result (after HAVING filtering).
    pub num_groups: usize,
    /// Number of GROUP BY expressions.
    pub num_expressions: usize,
    /// Time spent estimating group cardinality; reported as zero by
    /// `apply_group_by_expressions`.
    pub phase1_cardinality_estimation: Duration,
    /// Time spent building groups; `apply_group_by_expressions` stores the
    /// duration of the whole group-building call here.
    pub phase2_key_building: Duration,
    /// Time spent evaluating grouping expressions; reported as zero by
    /// `apply_group_by_expressions`.
    pub phase2_expression_evaluation: Duration,
    /// Time spent creating per-group views; reported as zero by
    /// `apply_group_by_expressions`.
    pub phase3_dataview_creation: Duration,
    /// Accumulated time spent computing per-group output values.
    pub phase4_aggregation: Duration,
    /// Accumulated time spent evaluating the HAVING predicate.
    pub phase4_having_evaluation: Duration,
    /// Number of groups rejected by the HAVING predicate.
    pub groups_filtered_by_having: usize,
    /// Wall-clock time of the whole operation.
    pub total_time: Duration,
}
31
32#[derive(Debug, Clone, PartialEq, Eq, Hash)]
34pub struct GroupKey(pub Vec<DataValue>);
35
36pub trait GroupByExpressions {
38 fn group_by_expressions(
40 &self,
41 view: DataView,
42 group_by_exprs: &[SqlExpression],
43 ) -> Result<FxHashMap<GroupKey, DataView>>;
44
45 fn apply_group_by_expressions(
47 &self,
48 view: DataView,
49 group_by_exprs: &[SqlExpression],
50 select_items: &[SelectItem],
51 having: Option<&SqlExpression>,
52 _case_insensitive: bool,
53 date_notation: String,
54 ) -> Result<(DataView, GroupByPhaseInfo)>;
55}
56
57impl GroupByExpressions for QueryEngine {
58 fn group_by_expressions(
59 &self,
60 view: DataView,
61 group_by_exprs: &[SqlExpression],
62 ) -> Result<FxHashMap<GroupKey, DataView>> {
63 use std::time::Instant;
64 let start = Instant::now();
65
66 let phase1_start = Instant::now();
68 let estimated_groups = self.estimate_group_cardinality(&view, group_by_exprs);
69 let mut groups = FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
70 let mut group_rows: FxHashMap<GroupKey, Vec<usize>> =
71 FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
72 let phase1_time = phase1_start.elapsed();
73 debug!(
74 "GROUP BY Phase 1 (cardinality estimation): {:?}, estimated {} groups",
75 phase1_time, estimated_groups
76 );
77
78 let phase2_start = Instant::now();
80 let visible_rows = view.get_visible_rows();
81 let total_rows = visible_rows.len();
82 debug!("GROUP BY Phase 2 starting: processing {} rows", total_rows);
83
84 let mut evaluator = ArithmeticEvaluator::new(view.source());
86
87 let mut key_values = Vec::with_capacity(group_by_exprs.len());
89
90 for row_idx in visible_rows.iter().copied() {
91 key_values.clear();
93
94 for expr in group_by_exprs {
96 let value = evaluator.evaluate(expr, row_idx).unwrap_or(DataValue::Null);
97 key_values.push(value);
98 }
99
100 let key = GroupKey(key_values.clone()); group_rows.entry(key).or_default().push(row_idx);
102 }
103 let phase2_time = phase2_start.elapsed();
104 debug!(
105 "GROUP BY Phase 2 (expression evaluation & key building): {:?}, created {} unique keys",
106 phase2_time,
107 group_rows.len()
108 );
109
110 let phase3_start = Instant::now();
112 for (key, rows) in group_rows {
113 let mut group_view = DataView::new(view.source_arc());
114 group_view = group_view.with_rows(rows);
115 groups.insert(key, group_view);
116 }
117 let phase3_time = phase3_start.elapsed();
118 debug!("GROUP BY Phase 3 (DataView creation): {:?}", phase3_time);
119
120 let total_time = start.elapsed();
121 debug!(
122 "GROUP BY Total time: {:?} (P1: {:?}, P2: {:?}, P3: {:?})",
123 total_time, phase1_time, phase2_time, phase3_time
124 );
125
126 Ok(groups)
127 }
128
129 fn apply_group_by_expressions(
130 &self,
131 view: DataView,
132 group_by_exprs: &[SqlExpression],
133 select_items: &[SelectItem],
134 having: Option<&SqlExpression>,
135 _case_insensitive: bool,
136 date_notation: String,
137 ) -> Result<(DataView, GroupByPhaseInfo)> {
138 use std::time::Instant;
139 let start = Instant::now();
140
141 debug!(
142 "apply_group_by_expressions - grouping by {} expressions, {} select items",
143 group_by_exprs.len(),
144 select_items.len()
145 );
146
147 let phase1_start = Instant::now();
149 let groups = self.group_by_expressions(view.clone(), group_by_exprs)?;
150 let phase1_time = phase1_start.elapsed();
151 debug!(
152 "apply_group_by_expressions Phase 1 (group building): {:?}, created {} groups",
153 phase1_time,
154 groups.len()
155 );
156
157 let mut result_table = DataTable::new("grouped_result");
159
160 let mut aggregate_columns = Vec::new();
162 let mut non_aggregate_exprs = Vec::new();
163 let mut derived_grouped_exprs: Vec<(SqlExpression, String)> = Vec::new();
167 let mut group_by_aliases = Vec::new();
168
169 for (i, group_expr) in group_by_exprs.iter().enumerate() {
171 let mut found_alias = None;
172
173 for item in select_items {
175 if let SelectItem::Expression { expr, alias, .. } = item {
176 if !contains_aggregate(expr) && expressions_match(expr, group_expr) {
177 found_alias = Some(alias.clone());
178 break;
179 }
180 }
181 }
182
183 let alias = found_alias.unwrap_or_else(|| match group_expr {
185 SqlExpression::Column(column_ref) => column_ref.name.clone(),
186 _ => format!("group_expr_{}", i + 1),
187 });
188
189 result_table.add_column(DataColumn::new(&alias));
190 group_by_aliases.push(alias);
191 }
192
193 for item in select_items {
195 match item {
196 SelectItem::Expression { expr, alias, .. } => {
197 if contains_aggregate(expr) {
198 result_table.add_column(DataColumn::new(alias));
200 aggregate_columns.push((expr.clone(), alias.clone()));
201 } else {
202 let mut found = false;
204 for group_expr in group_by_exprs {
205 if expressions_match(expr, group_expr) {
206 found = true;
207 non_aggregate_exprs.push((expr.clone(), alias.clone()));
208 break;
209 }
210 }
211 if !found {
212 let mut referenced_cols = Vec::new();
220 collect_column_refs(expr, &mut referenced_cols);
221 let all_grouped = !referenced_cols.is_empty()
222 && referenced_cols.iter().all(|col_name| {
223 group_by_exprs
224 .iter()
225 .any(|ge| expression_references_column(ge, col_name))
226 });
227 if !all_grouped {
228 return Err(anyhow!(
229 "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
230 alias
231 ));
232 }
233 result_table.add_column(DataColumn::new(alias));
236 derived_grouped_exprs.push((expr.clone(), alias.clone()));
237 }
238 }
239 }
240 SelectItem::Column {
241 column: col_ref, ..
242 } => {
243 let in_group_by = group_by_exprs.iter().any(
245 |expr| matches!(expr, SqlExpression::Column(name) if name.name == col_ref.name),
246 );
247
248 if !in_group_by {
249 return Err(anyhow!(
250 "Column '{}' must appear in GROUP BY clause or be used in an aggregate function",
251 col_ref.name
252 ));
253 }
254 }
255 SelectItem::Star { .. } => {
256 }
259 SelectItem::StarExclude { .. } => {
260 }
263 }
264 }
265
266 let phase2_start = Instant::now();
268 let mut aggregation_time = std::time::Duration::ZERO;
269 let mut having_time = std::time::Duration::ZERO;
270 let mut groups_processed = 0;
271 let mut groups_filtered = 0;
272
273 for (group_key, group_view) in groups {
274 let mut row_values = Vec::new();
275
276 for value in &group_key.0 {
278 row_values.push(value.clone());
279 }
280
281 let agg_start = Instant::now();
283 for (expr, _col_name) in &aggregate_columns {
284 let group_rows = group_view.get_visible_rows();
285 let mut evaluator = ArithmeticEvaluator::with_date_notation(
286 group_view.source(),
287 date_notation.clone(),
288 )
289 .with_visible_rows(group_rows.clone());
290
291 let value = if group_view.row_count() > 0 && !group_rows.is_empty() {
292 evaluator
293 .evaluate(expr, group_rows[0])
294 .unwrap_or(DataValue::Null)
295 } else {
296 DataValue::Null
297 };
298
299 row_values.push(value);
300 }
301
302 for (expr, _alias) in &derived_grouped_exprs {
306 let group_rows = group_view.get_visible_rows();
307 let value = if !group_rows.is_empty() {
308 let mut evaluator = ArithmeticEvaluator::with_date_notation(
309 group_view.source(),
310 date_notation.clone(),
311 );
312 evaluator
313 .evaluate(expr, group_rows[0])
314 .unwrap_or(DataValue::Null)
315 } else {
316 DataValue::Null
317 };
318 row_values.push(value);
319 }
320
321 aggregation_time += agg_start.elapsed();
322
323 let having_start = Instant::now();
325 if let Some(having_expr) = having {
326 let mut temp_table = DataTable::new("having_eval");
328
329 for alias in &group_by_aliases {
331 temp_table.add_column(DataColumn::new(alias));
332 }
333
334 for (_, alias) in &aggregate_columns {
336 temp_table.add_column(DataColumn::new(alias));
337 }
338
339 temp_table
340 .add_row(DataRow::new(row_values.clone()))
341 .map_err(|e| anyhow!("Failed to create temp table for HAVING: {}", e))?;
342
343 let mut evaluator =
345 ArithmeticEvaluator::with_date_notation(&temp_table, date_notation.clone());
346 let having_result = evaluator.evaluate(having_expr, 0)?;
347
348 if !is_truthy(&having_result) {
350 groups_filtered += 1;
351 having_time += having_start.elapsed();
352 continue;
353 }
354 }
355 having_time += having_start.elapsed();
356
357 groups_processed += 1;
358
359 result_table
361 .add_row(DataRow::new(row_values))
362 .map_err(|e| anyhow!("Failed to add grouped row: {}", e))?;
363 }
364
365 let phase2_time = phase2_start.elapsed();
366 let total_time = start.elapsed();
367
368 debug!(
369 "apply_group_by_expressions Phase 2 (aggregation): {:?}",
370 phase2_time
371 );
372 debug!(" - Aggregation time: {:?}", aggregation_time);
373 debug!(" - HAVING evaluation time: {:?}", having_time);
374 debug!(
375 " - Groups processed: {}, filtered by HAVING: {}",
376 groups_processed, groups_filtered
377 );
378 debug!(
379 "apply_group_by_expressions Total time: {:?} (P1: {:?}, P2: {:?})",
380 total_time, phase1_time, phase2_time
381 );
382
383 let phase_info = GroupByPhaseInfo {
384 total_rows: view.row_count(),
385 num_groups: groups_processed,
386 num_expressions: group_by_exprs.len(),
387 phase1_cardinality_estimation: Duration::ZERO, phase2_key_building: phase1_time, phase2_expression_evaluation: Duration::ZERO, phase3_dataview_creation: Duration::ZERO, phase4_aggregation: aggregation_time,
392 phase4_having_evaluation: having_time,
393 groups_filtered_by_having: groups_filtered,
394 total_time,
395 };
396
397 Ok((DataView::new(Arc::new(result_table)), phase_info))
398 }
399}
400
401fn expressions_match(expr1: &SqlExpression, expr2: &SqlExpression) -> bool {
403 format!("{:?}", expr1) == format!("{:?}", expr2)
406}
407
408fn collect_column_refs(expr: &SqlExpression, out: &mut Vec<String>) {
412 match expr {
413 SqlExpression::Column(col_ref) => out.push(col_ref.name.clone()),
414 SqlExpression::BinaryOp { left, right, .. } => {
415 collect_column_refs(left, out);
416 collect_column_refs(right, out);
417 }
418 SqlExpression::FunctionCall { args, .. } => {
419 for arg in args {
420 collect_column_refs(arg, out);
421 }
422 }
423 SqlExpression::Between { expr, lower, upper } => {
424 collect_column_refs(expr, out);
425 collect_column_refs(lower, out);
426 collect_column_refs(upper, out);
427 }
428 SqlExpression::Not { expr } => collect_column_refs(expr, out),
429 SqlExpression::CaseExpression {
430 when_branches,
431 else_branch,
432 } => {
433 for branch in when_branches {
434 collect_column_refs(&branch.condition, out);
435 collect_column_refs(&branch.result, out);
436 }
437 if let Some(else_expr) = else_branch {
438 collect_column_refs(else_expr, out);
439 }
440 }
441 SqlExpression::SimpleCaseExpression {
442 expr,
443 when_branches,
444 else_branch,
445 } => {
446 collect_column_refs(expr, out);
447 for branch in when_branches {
448 collect_column_refs(&branch.value, out);
449 collect_column_refs(&branch.result, out);
450 }
451 if let Some(else_expr) = else_branch {
452 collect_column_refs(else_expr, out);
453 }
454 }
455 _ => {}
457 }
458}
459
460fn expression_references_column(expr: &SqlExpression, column: &str) -> bool {
462 match expr {
463 SqlExpression::Column(name) => name == column,
464 SqlExpression::BinaryOp { left, right, .. } => {
465 expression_references_column(left, column)
466 || expression_references_column(right, column)
467 }
468 SqlExpression::FunctionCall { args, .. } => args
469 .iter()
470 .any(|arg| expression_references_column(arg, column)),
471 SqlExpression::Between { expr, lower, upper } => {
472 expression_references_column(expr, column)
473 || expression_references_column(lower, column)
474 || expression_references_column(upper, column)
475 }
476 _ => false,
477 }
478}
479
480fn is_truthy(value: &DataValue) -> bool {
482 match value {
483 DataValue::Boolean(b) => *b,
484 DataValue::Integer(i) => *i != 0,
485 DataValue::Float(f) => *f != 0.0 && !f.is_nan(),
486 DataValue::Null => false,
487 _ => true,
488 }
489}