1use anyhow::{anyhow, Result};
4use fxhash::FxHashMap;
5use std::sync::Arc;
6use std::time::Duration;
7
8use crate::data::arithmetic_evaluator::ArithmeticEvaluator;
9use crate::data::data_view::DataView;
10use crate::data::datatable::{DataColumn, DataRow, DataTable, DataValue};
11use crate::data::query_engine::QueryEngine;
12use crate::sql::aggregates::contains_aggregate;
13use crate::sql::parser::ast::{SelectItem, SqlExpression};
14use tracing::debug;
15
/// Size and timing statistics gathered while executing an expression-based GROUP BY.
///
/// The phase fields mirror the stages reported by the query engine; a phase that
/// is not measured separately is reported as `Duration::ZERO`.
#[derive(Clone, Debug)]
pub struct GroupByPhaseInfo {
    /// Row count of the input view that was grouped.
    pub total_rows: usize,
    /// Number of groups emitted into the result (i.e. after HAVING filtering).
    pub num_groups: usize,
    /// Number of GROUP BY expressions.
    pub num_expressions: usize,
    /// Time spent estimating how many distinct groups to expect.
    pub phase1_cardinality_estimation: Duration,
    /// Time spent building group keys.
    pub phase2_key_building: Duration,
    /// Time spent evaluating GROUP BY expressions.
    pub phase2_expression_evaluation: Duration,
    /// Time spent wrapping each group's rows in a view.
    pub phase3_dataview_creation: Duration,
    /// Accumulated time spent evaluating aggregate SELECT items across groups.
    pub phase4_aggregation: Duration,
    /// Accumulated time spent evaluating the HAVING predicate across groups.
    pub phase4_having_evaluation: Duration,
    /// Number of groups discarded because HAVING was not truthy.
    pub groups_filtered_by_having: usize,
    /// Wall-clock time for the whole GROUP BY operation.
    pub total_time: Duration,
}
31
32#[derive(Debug, Clone, PartialEq, Eq, Hash)]
34pub struct GroupKey(pub Vec<DataValue>);
35
36pub trait GroupByExpressions {
38 fn group_by_expressions(
40 &self,
41 view: DataView,
42 group_by_exprs: &[SqlExpression],
43 ) -> Result<FxHashMap<GroupKey, DataView>>;
44
45 fn apply_group_by_expressions(
47 &self,
48 view: DataView,
49 group_by_exprs: &[SqlExpression],
50 select_items: &[SelectItem],
51 having: Option<&SqlExpression>,
52 _case_insensitive: bool,
53 date_notation: String,
54 ) -> Result<(DataView, GroupByPhaseInfo)>;
55}
56
57impl GroupByExpressions for QueryEngine {
58 fn group_by_expressions(
59 &self,
60 view: DataView,
61 group_by_exprs: &[SqlExpression],
62 ) -> Result<FxHashMap<GroupKey, DataView>> {
63 use std::time::Instant;
64 let start = Instant::now();
65
66 let phase1_start = Instant::now();
68 let estimated_groups = self.estimate_group_cardinality(&view, group_by_exprs);
69 let mut groups = FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
70 let mut group_rows: FxHashMap<GroupKey, Vec<usize>> =
71 FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
72 let phase1_time = phase1_start.elapsed();
73 debug!(
74 "GROUP BY Phase 1 (cardinality estimation): {:?}, estimated {} groups",
75 phase1_time, estimated_groups
76 );
77
78 let phase2_start = Instant::now();
80 let visible_rows = view.get_visible_rows();
81 let total_rows = visible_rows.len();
82 debug!("GROUP BY Phase 2 starting: processing {} rows", total_rows);
83
84 let mut evaluator = ArithmeticEvaluator::new(view.source());
86
87 let mut key_values = Vec::with_capacity(group_by_exprs.len());
89
90 for row_idx in visible_rows.iter().copied() {
91 key_values.clear();
93
94 for expr in group_by_exprs {
96 let value = evaluator.evaluate(expr, row_idx).unwrap_or(DataValue::Null);
97 key_values.push(value);
98 }
99
100 let key = GroupKey(key_values.clone()); group_rows.entry(key).or_default().push(row_idx);
102 }
103 let phase2_time = phase2_start.elapsed();
104 debug!(
105 "GROUP BY Phase 2 (expression evaluation & key building): {:?}, created {} unique keys",
106 phase2_time,
107 group_rows.len()
108 );
109
110 let phase3_start = Instant::now();
112 for (key, rows) in group_rows {
113 let mut group_view = DataView::new(view.source_arc());
114 group_view = group_view.with_rows(rows);
115 groups.insert(key, group_view);
116 }
117 let phase3_time = phase3_start.elapsed();
118 debug!("GROUP BY Phase 3 (DataView creation): {:?}", phase3_time);
119
120 let total_time = start.elapsed();
121 debug!(
122 "GROUP BY Total time: {:?} (P1: {:?}, P2: {:?}, P3: {:?})",
123 total_time, phase1_time, phase2_time, phase3_time
124 );
125
126 Ok(groups)
127 }
128
129 fn apply_group_by_expressions(
130 &self,
131 view: DataView,
132 group_by_exprs: &[SqlExpression],
133 select_items: &[SelectItem],
134 having: Option<&SqlExpression>,
135 _case_insensitive: bool,
136 date_notation: String,
137 ) -> Result<(DataView, GroupByPhaseInfo)> {
138 use std::time::Instant;
139 let start = Instant::now();
140
141 debug!(
142 "apply_group_by_expressions - grouping by {} expressions, {} select items",
143 group_by_exprs.len(),
144 select_items.len()
145 );
146
147 let phase1_start = Instant::now();
149 let groups = self.group_by_expressions(view.clone(), group_by_exprs)?;
150 let phase1_time = phase1_start.elapsed();
151 debug!(
152 "apply_group_by_expressions Phase 1 (group building): {:?}, created {} groups",
153 phase1_time,
154 groups.len()
155 );
156
157 let mut result_table = DataTable::new("grouped_result");
159
160 let mut aggregate_columns = Vec::new();
162 let mut non_aggregate_exprs = Vec::new();
163 let mut group_by_aliases = Vec::new();
164
165 for (i, group_expr) in group_by_exprs.iter().enumerate() {
167 let mut found_alias = None;
168
169 for item in select_items {
171 if let SelectItem::Expression { expr, alias, .. } = item {
172 if !contains_aggregate(expr) && expressions_match(expr, group_expr) {
173 found_alias = Some(alias.clone());
174 break;
175 }
176 }
177 }
178
179 let alias = found_alias.unwrap_or_else(|| match group_expr {
181 SqlExpression::Column(column_ref) => column_ref.name.clone(),
182 _ => format!("group_expr_{}", i + 1),
183 });
184
185 result_table.add_column(DataColumn::new(&alias));
186 group_by_aliases.push(alias);
187 }
188
189 for item in select_items {
191 match item {
192 SelectItem::Expression { expr, alias, .. } => {
193 if contains_aggregate(expr) {
194 result_table.add_column(DataColumn::new(alias));
196 aggregate_columns.push((expr.clone(), alias.clone()));
197 } else {
198 let mut found = false;
200 for group_expr in group_by_exprs {
201 if expressions_match(expr, group_expr) {
202 found = true;
203 non_aggregate_exprs.push((expr.clone(), alias.clone()));
204 break;
205 }
206 }
207 if !found {
208 if let SqlExpression::Column(col) = expr {
210 let referenced = group_by_exprs
212 .iter()
213 .any(|ge| expression_references_column(ge, &col.name));
214 if !referenced {
215 return Err(anyhow!(
216 "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
217 alias
218 ));
219 }
220 } else {
221 return Err(anyhow!(
222 "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
223 alias
224 ));
225 }
226 }
227 }
228 }
229 SelectItem::Column {
230 column: col_ref, ..
231 } => {
232 let in_group_by = group_by_exprs.iter().any(
234 |expr| matches!(expr, SqlExpression::Column(name) if name.name == col_ref.name),
235 );
236
237 if !in_group_by {
238 return Err(anyhow!(
239 "Column '{}' must appear in GROUP BY clause or be used in an aggregate function",
240 col_ref.name
241 ));
242 }
243 }
244 SelectItem::Star { .. } => {
245 }
248 }
249 }
250
251 let phase2_start = Instant::now();
253 let mut aggregation_time = std::time::Duration::ZERO;
254 let mut having_time = std::time::Duration::ZERO;
255 let mut groups_processed = 0;
256 let mut groups_filtered = 0;
257
258 for (group_key, group_view) in groups {
259 let mut row_values = Vec::new();
260
261 for value in &group_key.0 {
263 row_values.push(value.clone());
264 }
265
266 let agg_start = Instant::now();
268 for (expr, _col_name) in &aggregate_columns {
269 let group_rows = group_view.get_visible_rows();
270 let mut evaluator = ArithmeticEvaluator::with_date_notation(
271 group_view.source(),
272 date_notation.clone(),
273 )
274 .with_visible_rows(group_rows.clone());
275
276 let value = if group_view.row_count() > 0 && !group_rows.is_empty() {
277 evaluator
278 .evaluate(expr, group_rows[0])
279 .unwrap_or(DataValue::Null)
280 } else {
281 DataValue::Null
282 };
283
284 row_values.push(value);
285 }
286 aggregation_time += agg_start.elapsed();
287
288 let having_start = Instant::now();
290 if let Some(having_expr) = having {
291 let mut temp_table = DataTable::new("having_eval");
293
294 for alias in &group_by_aliases {
296 temp_table.add_column(DataColumn::new(alias));
297 }
298
299 for (_, alias) in &aggregate_columns {
301 temp_table.add_column(DataColumn::new(alias));
302 }
303
304 temp_table
305 .add_row(DataRow::new(row_values.clone()))
306 .map_err(|e| anyhow!("Failed to create temp table for HAVING: {}", e))?;
307
308 let mut evaluator =
310 ArithmeticEvaluator::with_date_notation(&temp_table, date_notation.clone());
311 let having_result = evaluator.evaluate(having_expr, 0)?;
312
313 if !is_truthy(&having_result) {
315 groups_filtered += 1;
316 having_time += having_start.elapsed();
317 continue;
318 }
319 }
320 having_time += having_start.elapsed();
321
322 groups_processed += 1;
323
324 result_table
326 .add_row(DataRow::new(row_values))
327 .map_err(|e| anyhow!("Failed to add grouped row: {}", e))?;
328 }
329
330 let phase2_time = phase2_start.elapsed();
331 let total_time = start.elapsed();
332
333 debug!(
334 "apply_group_by_expressions Phase 2 (aggregation): {:?}",
335 phase2_time
336 );
337 debug!(" - Aggregation time: {:?}", aggregation_time);
338 debug!(" - HAVING evaluation time: {:?}", having_time);
339 debug!(
340 " - Groups processed: {}, filtered by HAVING: {}",
341 groups_processed, groups_filtered
342 );
343 debug!(
344 "apply_group_by_expressions Total time: {:?} (P1: {:?}, P2: {:?})",
345 total_time, phase1_time, phase2_time
346 );
347
348 let phase_info = GroupByPhaseInfo {
349 total_rows: view.row_count(),
350 num_groups: groups_processed,
351 num_expressions: group_by_exprs.len(),
352 phase1_cardinality_estimation: Duration::ZERO, phase2_key_building: phase1_time, phase2_expression_evaluation: Duration::ZERO, phase3_dataview_creation: Duration::ZERO, phase4_aggregation: aggregation_time,
357 phase4_having_evaluation: having_time,
358 groups_filtered_by_having: groups_filtered,
359 total_time,
360 };
361
362 Ok((DataView::new(Arc::new(result_table)), phase_info))
363 }
364}
365
366fn expressions_match(expr1: &SqlExpression, expr2: &SqlExpression) -> bool {
368 format!("{:?}", expr1) == format!("{:?}", expr2)
371}
372
373fn expression_references_column(expr: &SqlExpression, column: &str) -> bool {
375 match expr {
376 SqlExpression::Column(name) => name == column,
377 SqlExpression::BinaryOp { left, right, .. } => {
378 expression_references_column(left, column)
379 || expression_references_column(right, column)
380 }
381 SqlExpression::FunctionCall { args, .. } => args
382 .iter()
383 .any(|arg| expression_references_column(arg, column)),
384 SqlExpression::Between { expr, lower, upper } => {
385 expression_references_column(expr, column)
386 || expression_references_column(lower, column)
387 || expression_references_column(upper, column)
388 }
389 _ => false,
390 }
391}
392
393fn is_truthy(value: &DataValue) -> bool {
395 match value {
396 DataValue::Boolean(b) => *b,
397 DataValue::Integer(i) => *i != 0,
398 DataValue::Float(f) => *f != 0.0 && !f.is_nan(),
399 DataValue::Null => false,
400 _ => true,
401 }
402}