1use anyhow::{anyhow, Result};
4use fxhash::FxHashMap;
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::Duration;
8
9use crate::data::arithmetic_evaluator::ArithmeticEvaluator;
10use crate::data::data_view::DataView;
11use crate::data::datatable::{DataColumn, DataRow, DataTable, DataValue};
12use crate::data::query_engine::QueryEngine;
13use crate::sql::aggregates::contains_aggregate;
14use crate::sql::parser::ast::{ColumnRef, SelectItem, SqlExpression};
15use tracing::debug;
16
/// Timing and cardinality statistics collected while executing a GROUP BY.
///
/// All `Duration` fields are wall-clock measurements. Not every producer
/// fills every phase: fields that were not measured separately are expected
/// to be left at `Duration::ZERO`.
#[derive(Debug, Clone)]
pub struct GroupByPhaseInfo {
    // --- cardinalities -----------------------------------------------------
    /// Number of rows in the input view.
    pub total_rows: usize,
    /// Number of groups reported by the operation.
    pub num_groups: usize,
    /// Number of GROUP BY expressions that formed the key.
    pub num_expressions: usize,

    // --- per-phase timings -------------------------------------------------
    /// Time spent estimating how many distinct groups to expect.
    pub phase1_cardinality_estimation: Duration,
    /// Time spent building group keys.
    pub phase2_key_building: Duration,
    /// Time spent evaluating the GROUP BY expressions.
    pub phase2_expression_evaluation: Duration,
    /// Time spent creating the per-group `DataView`s.
    pub phase3_dataview_creation: Duration,
    /// Accumulated time spent evaluating aggregate expressions.
    pub phase4_aggregation: Duration,
    /// Accumulated time spent evaluating the HAVING predicate.
    pub phase4_having_evaluation: Duration,

    // --- totals ------------------------------------------------------------
    /// Number of groups dropped because HAVING evaluated to a falsy value.
    pub groups_filtered_by_having: usize,
    /// End-to-end duration of the GROUP BY operation.
    pub total_time: Duration,
}
32
33#[derive(Debug, Clone, PartialEq, Eq, Hash)]
35pub struct GroupKey(pub Vec<DataValue>);
36
37pub trait GroupByExpressions {
39 fn group_by_expressions(
41 &self,
42 view: DataView,
43 group_by_exprs: &[SqlExpression],
44 ) -> Result<FxHashMap<GroupKey, DataView>>;
45
46 fn apply_group_by_expressions(
48 &self,
49 view: DataView,
50 group_by_exprs: &[SqlExpression],
51 select_items: &[SelectItem],
52 having: Option<&SqlExpression>,
53 _case_insensitive: bool,
54 date_notation: String,
55 ) -> Result<(DataView, GroupByPhaseInfo)>;
56}
57
58impl GroupByExpressions for QueryEngine {
59 fn group_by_expressions(
60 &self,
61 view: DataView,
62 group_by_exprs: &[SqlExpression],
63 ) -> Result<FxHashMap<GroupKey, DataView>> {
64 use std::time::Instant;
65 let start = Instant::now();
66
67 let phase1_start = Instant::now();
69 let estimated_groups = self.estimate_group_cardinality(&view, group_by_exprs);
70 let mut groups = FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
71 let mut group_rows: FxHashMap<GroupKey, Vec<usize>> =
72 FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
73 let phase1_time = phase1_start.elapsed();
74 debug!(
75 "GROUP BY Phase 1 (cardinality estimation): {:?}, estimated {} groups",
76 phase1_time, estimated_groups
77 );
78
79 let phase2_start = Instant::now();
81 let visible_rows = view.get_visible_rows();
82 let total_rows = visible_rows.len();
83 debug!("GROUP BY Phase 2 starting: processing {} rows", total_rows);
84
85 let mut evaluator = ArithmeticEvaluator::new(view.source());
87
88 let mut key_values = Vec::with_capacity(group_by_exprs.len());
90
91 for row_idx in visible_rows.iter().copied() {
92 key_values.clear();
94
95 for expr in group_by_exprs {
97 let value = evaluator.evaluate(expr, row_idx).unwrap_or(DataValue::Null);
98 key_values.push(value);
99 }
100
101 let key = GroupKey(key_values.clone()); group_rows.entry(key).or_default().push(row_idx);
103 }
104 let phase2_time = phase2_start.elapsed();
105 debug!(
106 "GROUP BY Phase 2 (expression evaluation & key building): {:?}, created {} unique keys",
107 phase2_time,
108 group_rows.len()
109 );
110
111 let phase3_start = Instant::now();
113 for (key, rows) in group_rows {
114 let mut group_view = DataView::new(view.source_arc());
115 group_view = group_view.with_rows(rows);
116 groups.insert(key, group_view);
117 }
118 let phase3_time = phase3_start.elapsed();
119 debug!("GROUP BY Phase 3 (DataView creation): {:?}", phase3_time);
120
121 let total_time = start.elapsed();
122 debug!(
123 "GROUP BY Total time: {:?} (P1: {:?}, P2: {:?}, P3: {:?})",
124 total_time, phase1_time, phase2_time, phase3_time
125 );
126
127 Ok(groups)
128 }
129
130 fn apply_group_by_expressions(
131 &self,
132 view: DataView,
133 group_by_exprs: &[SqlExpression],
134 select_items: &[SelectItem],
135 having: Option<&SqlExpression>,
136 _case_insensitive: bool,
137 date_notation: String,
138 ) -> Result<(DataView, GroupByPhaseInfo)> {
139 use std::time::Instant;
140 let start = Instant::now();
141
142 debug!(
143 "apply_group_by_expressions - grouping by {} expressions, {} select items",
144 group_by_exprs.len(),
145 select_items.len()
146 );
147
148 let phase1_start = Instant::now();
150 let groups = self.group_by_expressions(view.clone(), group_by_exprs)?;
151 let phase1_time = phase1_start.elapsed();
152 debug!(
153 "apply_group_by_expressions Phase 1 (group building): {:?}, created {} groups",
154 phase1_time,
155 groups.len()
156 );
157
158 let mut result_table = DataTable::new("grouped_result");
160
161 let mut aggregate_columns = Vec::new();
163 let mut non_aggregate_exprs = Vec::new();
164 let mut group_by_aliases = Vec::new();
165
166 for (i, group_expr) in group_by_exprs.iter().enumerate() {
168 let mut found_alias = None;
169
170 for item in select_items {
172 if let SelectItem::Expression { expr, alias } = item {
173 if !contains_aggregate(expr) && expressions_match(expr, group_expr) {
174 found_alias = Some(alias.clone());
175 break;
176 }
177 }
178 }
179
180 let alias = found_alias.unwrap_or_else(|| match group_expr {
182 SqlExpression::Column(column_ref) => column_ref.name.clone(),
183 _ => format!("group_expr_{}", i + 1),
184 });
185
186 result_table.add_column(DataColumn::new(&alias));
187 group_by_aliases.push(alias);
188 }
189
190 for item in select_items {
192 match item {
193 SelectItem::Expression { expr, alias } => {
194 if contains_aggregate(expr) {
195 result_table.add_column(DataColumn::new(alias));
197 aggregate_columns.push((expr.clone(), alias.clone()));
198 } else {
199 let mut found = false;
201 for group_expr in group_by_exprs {
202 if expressions_match(expr, group_expr) {
203 found = true;
204 non_aggregate_exprs.push((expr.clone(), alias.clone()));
205 break;
206 }
207 }
208 if !found {
209 if let SqlExpression::Column(col) = expr {
211 let referenced = group_by_exprs
213 .iter()
214 .any(|ge| expression_references_column(ge, &col.name));
215 if !referenced {
216 return Err(anyhow!(
217 "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
218 alias
219 ));
220 }
221 } else {
222 return Err(anyhow!(
223 "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
224 alias
225 ));
226 }
227 }
228 }
229 }
230 SelectItem::Column(col_ref) => {
231 let in_group_by = group_by_exprs.iter().any(
233 |expr| matches!(expr, SqlExpression::Column(name) if name.name == col_ref.name),
234 );
235
236 if !in_group_by {
237 return Err(anyhow!(
238 "Column '{}' must appear in GROUP BY clause or be used in an aggregate function",
239 col_ref.name
240 ));
241 }
242 }
243 SelectItem::Star => {
244 }
247 }
248 }
249
250 let phase2_start = Instant::now();
252 let mut aggregation_time = std::time::Duration::ZERO;
253 let mut having_time = std::time::Duration::ZERO;
254 let mut groups_processed = 0;
255 let mut groups_filtered = 0;
256
257 for (group_key, group_view) in groups {
258 let mut row_values = Vec::new();
259
260 for value in &group_key.0 {
262 row_values.push(value.clone());
263 }
264
265 let agg_start = Instant::now();
267 for (expr, _col_name) in &aggregate_columns {
268 let group_rows = group_view.get_visible_rows();
269 let mut evaluator = ArithmeticEvaluator::with_date_notation(
270 group_view.source(),
271 date_notation.clone(),
272 )
273 .with_visible_rows(group_rows.clone());
274
275 let value = if group_view.row_count() > 0 && !group_rows.is_empty() {
276 evaluator
277 .evaluate(expr, group_rows[0])
278 .unwrap_or(DataValue::Null)
279 } else {
280 DataValue::Null
281 };
282
283 row_values.push(value);
284 }
285 aggregation_time += agg_start.elapsed();
286
287 let having_start = Instant::now();
289 if let Some(having_expr) = having {
290 let mut temp_table = DataTable::new("having_eval");
292
293 for alias in &group_by_aliases {
295 temp_table.add_column(DataColumn::new(alias));
296 }
297
298 for (_, alias) in &aggregate_columns {
300 temp_table.add_column(DataColumn::new(alias));
301 }
302
303 temp_table
304 .add_row(DataRow::new(row_values.clone()))
305 .map_err(|e| anyhow!("Failed to create temp table for HAVING: {}", e))?;
306
307 let mut evaluator =
309 ArithmeticEvaluator::with_date_notation(&temp_table, date_notation.clone());
310 let having_result = evaluator.evaluate(having_expr, 0)?;
311
312 if !is_truthy(&having_result) {
314 groups_filtered += 1;
315 having_time += having_start.elapsed();
316 continue;
317 }
318 }
319 having_time += having_start.elapsed();
320
321 groups_processed += 1;
322
323 result_table
325 .add_row(DataRow::new(row_values))
326 .map_err(|e| anyhow!("Failed to add grouped row: {}", e))?;
327 }
328
329 let phase2_time = phase2_start.elapsed();
330 let total_time = start.elapsed();
331
332 debug!(
333 "apply_group_by_expressions Phase 2 (aggregation): {:?}",
334 phase2_time
335 );
336 debug!(" - Aggregation time: {:?}", aggregation_time);
337 debug!(" - HAVING evaluation time: {:?}", having_time);
338 debug!(
339 " - Groups processed: {}, filtered by HAVING: {}",
340 groups_processed, groups_filtered
341 );
342 debug!(
343 "apply_group_by_expressions Total time: {:?} (P1: {:?}, P2: {:?})",
344 total_time, phase1_time, phase2_time
345 );
346
347 let phase_info = GroupByPhaseInfo {
348 total_rows: view.row_count(),
349 num_groups: groups_processed,
350 num_expressions: group_by_exprs.len(),
351 phase1_cardinality_estimation: Duration::ZERO, phase2_key_building: phase1_time, phase2_expression_evaluation: Duration::ZERO, phase3_dataview_creation: Duration::ZERO, phase4_aggregation: aggregation_time,
356 phase4_having_evaluation: having_time,
357 groups_filtered_by_having: groups_filtered,
358 total_time,
359 };
360
361 Ok((DataView::new(Arc::new(result_table)), phase_info))
362 }
363}
364
365fn expressions_match(expr1: &SqlExpression, expr2: &SqlExpression) -> bool {
367 format!("{:?}", expr1) == format!("{:?}", expr2)
370}
371
372fn expression_references_column(expr: &SqlExpression, column: &str) -> bool {
374 match expr {
375 SqlExpression::Column(name) => name == column,
376 SqlExpression::BinaryOp { left, right, .. } => {
377 expression_references_column(left, column)
378 || expression_references_column(right, column)
379 }
380 SqlExpression::FunctionCall { args, .. } => args
381 .iter()
382 .any(|arg| expression_references_column(arg, column)),
383 SqlExpression::Between { expr, lower, upper } => {
384 expression_references_column(expr, column)
385 || expression_references_column(lower, column)
386 || expression_references_column(upper, column)
387 }
388 _ => false,
389 }
390}
391
392fn is_truthy(value: &DataValue) -> bool {
394 match value {
395 DataValue::Boolean(b) => *b,
396 DataValue::Integer(i) => *i != 0,
397 DataValue::Float(f) => *f != 0.0 && !f.is_nan(),
398 DataValue::Null => false,
399 _ => true,
400 }
401}