1use anyhow::{anyhow, Result};
4use fxhash::FxHashMap;
5use std::sync::Arc;
6use std::time::Duration;
7
8use crate::data::arithmetic_evaluator::ArithmeticEvaluator;
9use crate::data::data_view::DataView;
10use crate::data::datatable::{DataColumn, DataRow, DataTable, DataValue};
11use crate::data::query_engine::QueryEngine;
12use crate::sql::aggregates::contains_aggregate;
13use crate::sql::parser::ast::{SelectItem, SqlExpression};
14use tracing::debug;
15
/// Timing and count statistics describing one GROUP BY execution.
///
/// The field notes below describe how the `QueryEngine` implementation of
/// `GroupByExpressions::apply_group_by_expressions` in this file fills the
/// struct; other producers, if any exist, are not visible from here.
#[derive(Debug, Clone)]
pub struct GroupByPhaseInfo {
    /// Row count of the input view (`view.row_count()`).
    pub total_rows: usize,
    /// Number of groups written to the result, i.e. groups that were not
    /// rejected by the HAVING clause.
    pub num_groups: usize,
    /// Number of GROUP BY expressions.
    pub num_expressions: usize,
    /// Not measured separately by `apply_group_by_expressions`; always
    /// `Duration::ZERO` there.
    pub phase1_cardinality_estimation: Duration,
    /// Receives the elapsed time of the whole `group_by_expressions` call
    /// (estimation, key building and per-group view creation together).
    pub phase2_key_building: Duration,
    /// Not measured separately by `apply_group_by_expressions`; always
    /// `Duration::ZERO` there.
    pub phase2_expression_evaluation: Duration,
    /// Not measured separately by `apply_group_by_expressions`; always
    /// `Duration::ZERO` there.
    pub phase3_dataview_creation: Duration,
    /// Time accumulated across all groups while evaluating aggregate
    /// select expressions.
    pub phase4_aggregation: Duration,
    /// Time accumulated across all groups while evaluating the HAVING
    /// clause (zero-cost timer reads only when there is no HAVING).
    pub phase4_having_evaluation: Duration,
    /// Number of groups dropped because the HAVING result was not truthy.
    pub groups_filtered_by_having: usize,
    /// Wall-clock time of the whole `apply_group_by_expressions` call.
    pub total_time: Duration,
}
31
/// Hash-map key identifying one group.
///
/// Holds the evaluated value of each GROUP BY expression for a row, in the
/// same order as the expressions were supplied. Two rows belong to the same
/// group when their keys compare equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GroupKey(pub Vec<DataValue>);
35
/// GROUP BY support for arbitrary expressions (not just plain columns).
pub trait GroupByExpressions {
    /// Splits `view` into one `DataView` per distinct combination of
    /// `group_by_exprs` values.
    ///
    /// Returns a map from the evaluated key to the view containing that
    /// group's rows.
    fn group_by_expressions(
        &self,
        view: DataView,
        group_by_exprs: &[SqlExpression],
    ) -> Result<FxHashMap<GroupKey, DataView>>;

    /// Groups `view`, evaluates the aggregate `select_items` per group,
    /// optionally filters groups with `having`, and returns the grouped
    /// result together with timing statistics.
    ///
    /// `date_notation` is forwarded to the expression evaluator.
    /// `_case_insensitive` is part of the signature but is not read by the
    /// `QueryEngine` implementation in this file.
    fn apply_group_by_expressions(
        &self,
        view: DataView,
        group_by_exprs: &[SqlExpression],
        select_items: &[SelectItem],
        having: Option<&SqlExpression>,
        _case_insensitive: bool,
        date_notation: String,
    ) -> Result<(DataView, GroupByPhaseInfo)>;
}
56
57impl GroupByExpressions for QueryEngine {
58 fn group_by_expressions(
59 &self,
60 view: DataView,
61 group_by_exprs: &[SqlExpression],
62 ) -> Result<FxHashMap<GroupKey, DataView>> {
63 use std::time::Instant;
64 let start = Instant::now();
65
66 let phase1_start = Instant::now();
68 let estimated_groups = self.estimate_group_cardinality(&view, group_by_exprs);
69 let mut groups = FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
70 let mut group_rows: FxHashMap<GroupKey, Vec<usize>> =
71 FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
72 let phase1_time = phase1_start.elapsed();
73 debug!(
74 "GROUP BY Phase 1 (cardinality estimation): {:?}, estimated {} groups",
75 phase1_time, estimated_groups
76 );
77
78 let phase2_start = Instant::now();
80 let visible_rows = view.get_visible_rows();
81 let total_rows = visible_rows.len();
82 debug!("GROUP BY Phase 2 starting: processing {} rows", total_rows);
83
84 let mut evaluator = ArithmeticEvaluator::new(view.source());
86
87 let mut key_values = Vec::with_capacity(group_by_exprs.len());
89
90 for row_idx in visible_rows.iter().copied() {
91 key_values.clear();
93
94 for expr in group_by_exprs {
96 let value = evaluator.evaluate(expr, row_idx).unwrap_or(DataValue::Null);
97 key_values.push(value);
98 }
99
100 let key = GroupKey(key_values.clone()); group_rows.entry(key).or_default().push(row_idx);
102 }
103 let phase2_time = phase2_start.elapsed();
104 debug!(
105 "GROUP BY Phase 2 (expression evaluation & key building): {:?}, created {} unique keys",
106 phase2_time,
107 group_rows.len()
108 );
109
110 let phase3_start = Instant::now();
112 for (key, rows) in group_rows {
113 let mut group_view = DataView::new(view.source_arc());
114 group_view = group_view.with_rows(rows);
115 groups.insert(key, group_view);
116 }
117 let phase3_time = phase3_start.elapsed();
118 debug!("GROUP BY Phase 3 (DataView creation): {:?}", phase3_time);
119
120 let total_time = start.elapsed();
121 debug!(
122 "GROUP BY Total time: {:?} (P1: {:?}, P2: {:?}, P3: {:?})",
123 total_time, phase1_time, phase2_time, phase3_time
124 );
125
126 Ok(groups)
127 }
128
129 fn apply_group_by_expressions(
130 &self,
131 view: DataView,
132 group_by_exprs: &[SqlExpression],
133 select_items: &[SelectItem],
134 having: Option<&SqlExpression>,
135 _case_insensitive: bool,
136 date_notation: String,
137 ) -> Result<(DataView, GroupByPhaseInfo)> {
138 use std::time::Instant;
139 let start = Instant::now();
140
141 debug!(
142 "apply_group_by_expressions - grouping by {} expressions, {} select items",
143 group_by_exprs.len(),
144 select_items.len()
145 );
146
147 let phase1_start = Instant::now();
149 let groups = self.group_by_expressions(view.clone(), group_by_exprs)?;
150 let phase1_time = phase1_start.elapsed();
151 debug!(
152 "apply_group_by_expressions Phase 1 (group building): {:?}, created {} groups",
153 phase1_time,
154 groups.len()
155 );
156
157 let mut result_table = DataTable::new("grouped_result");
159
160 let mut aggregate_columns = Vec::new();
162 let mut non_aggregate_exprs = Vec::new();
163 let mut group_by_aliases = Vec::new();
164
165 for (i, group_expr) in group_by_exprs.iter().enumerate() {
167 let mut found_alias = None;
168
169 for item in select_items {
171 if let SelectItem::Expression { expr, alias } = item {
172 if !contains_aggregate(expr) && expressions_match(expr, group_expr) {
173 found_alias = Some(alias.clone());
174 break;
175 }
176 }
177 }
178
179 let alias = found_alias.unwrap_or_else(|| match group_expr {
181 SqlExpression::Column(column_ref) => column_ref.name.clone(),
182 _ => format!("group_expr_{}", i + 1),
183 });
184
185 result_table.add_column(DataColumn::new(&alias));
186 group_by_aliases.push(alias);
187 }
188
189 for item in select_items {
191 match item {
192 SelectItem::Expression { expr, alias } => {
193 if contains_aggregate(expr) {
194 result_table.add_column(DataColumn::new(alias));
196 aggregate_columns.push((expr.clone(), alias.clone()));
197 } else {
198 let mut found = false;
200 for group_expr in group_by_exprs {
201 if expressions_match(expr, group_expr) {
202 found = true;
203 non_aggregate_exprs.push((expr.clone(), alias.clone()));
204 break;
205 }
206 }
207 if !found {
208 if let SqlExpression::Column(col) = expr {
210 let referenced = group_by_exprs
212 .iter()
213 .any(|ge| expression_references_column(ge, &col.name));
214 if !referenced {
215 return Err(anyhow!(
216 "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
217 alias
218 ));
219 }
220 } else {
221 return Err(anyhow!(
222 "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
223 alias
224 ));
225 }
226 }
227 }
228 }
229 SelectItem::Column(col_ref) => {
230 let in_group_by = group_by_exprs.iter().any(
232 |expr| matches!(expr, SqlExpression::Column(name) if name.name == col_ref.name),
233 );
234
235 if !in_group_by {
236 return Err(anyhow!(
237 "Column '{}' must appear in GROUP BY clause or be used in an aggregate function",
238 col_ref.name
239 ));
240 }
241 }
242 SelectItem::Star => {
243 }
246 }
247 }
248
249 let phase2_start = Instant::now();
251 let mut aggregation_time = std::time::Duration::ZERO;
252 let mut having_time = std::time::Duration::ZERO;
253 let mut groups_processed = 0;
254 let mut groups_filtered = 0;
255
256 for (group_key, group_view) in groups {
257 let mut row_values = Vec::new();
258
259 for value in &group_key.0 {
261 row_values.push(value.clone());
262 }
263
264 let agg_start = Instant::now();
266 for (expr, _col_name) in &aggregate_columns {
267 let group_rows = group_view.get_visible_rows();
268 let mut evaluator = ArithmeticEvaluator::with_date_notation(
269 group_view.source(),
270 date_notation.clone(),
271 )
272 .with_visible_rows(group_rows.clone());
273
274 let value = if group_view.row_count() > 0 && !group_rows.is_empty() {
275 evaluator
276 .evaluate(expr, group_rows[0])
277 .unwrap_or(DataValue::Null)
278 } else {
279 DataValue::Null
280 };
281
282 row_values.push(value);
283 }
284 aggregation_time += agg_start.elapsed();
285
286 let having_start = Instant::now();
288 if let Some(having_expr) = having {
289 let mut temp_table = DataTable::new("having_eval");
291
292 for alias in &group_by_aliases {
294 temp_table.add_column(DataColumn::new(alias));
295 }
296
297 for (_, alias) in &aggregate_columns {
299 temp_table.add_column(DataColumn::new(alias));
300 }
301
302 temp_table
303 .add_row(DataRow::new(row_values.clone()))
304 .map_err(|e| anyhow!("Failed to create temp table for HAVING: {}", e))?;
305
306 let mut evaluator =
308 ArithmeticEvaluator::with_date_notation(&temp_table, date_notation.clone());
309 let having_result = evaluator.evaluate(having_expr, 0)?;
310
311 if !is_truthy(&having_result) {
313 groups_filtered += 1;
314 having_time += having_start.elapsed();
315 continue;
316 }
317 }
318 having_time += having_start.elapsed();
319
320 groups_processed += 1;
321
322 result_table
324 .add_row(DataRow::new(row_values))
325 .map_err(|e| anyhow!("Failed to add grouped row: {}", e))?;
326 }
327
328 let phase2_time = phase2_start.elapsed();
329 let total_time = start.elapsed();
330
331 debug!(
332 "apply_group_by_expressions Phase 2 (aggregation): {:?}",
333 phase2_time
334 );
335 debug!(" - Aggregation time: {:?}", aggregation_time);
336 debug!(" - HAVING evaluation time: {:?}", having_time);
337 debug!(
338 " - Groups processed: {}, filtered by HAVING: {}",
339 groups_processed, groups_filtered
340 );
341 debug!(
342 "apply_group_by_expressions Total time: {:?} (P1: {:?}, P2: {:?})",
343 total_time, phase1_time, phase2_time
344 );
345
346 let phase_info = GroupByPhaseInfo {
347 total_rows: view.row_count(),
348 num_groups: groups_processed,
349 num_expressions: group_by_exprs.len(),
350 phase1_cardinality_estimation: Duration::ZERO, phase2_key_building: phase1_time, phase2_expression_evaluation: Duration::ZERO, phase3_dataview_creation: Duration::ZERO, phase4_aggregation: aggregation_time,
355 phase4_having_evaluation: having_time,
356 groups_filtered_by_having: groups_filtered,
357 total_time,
358 };
359
360 Ok((DataView::new(Arc::new(result_table)), phase_info))
361 }
362}
363
364fn expressions_match(expr1: &SqlExpression, expr2: &SqlExpression) -> bool {
366 format!("{:?}", expr1) == format!("{:?}", expr2)
369}
370
371fn expression_references_column(expr: &SqlExpression, column: &str) -> bool {
373 match expr {
374 SqlExpression::Column(name) => name == column,
375 SqlExpression::BinaryOp { left, right, .. } => {
376 expression_references_column(left, column)
377 || expression_references_column(right, column)
378 }
379 SqlExpression::FunctionCall { args, .. } => args
380 .iter()
381 .any(|arg| expression_references_column(arg, column)),
382 SqlExpression::Between { expr, lower, upper } => {
383 expression_references_column(expr, column)
384 || expression_references_column(lower, column)
385 || expression_references_column(upper, column)
386 }
387 _ => false,
388 }
389}
390
391fn is_truthy(value: &DataValue) -> bool {
393 match value {
394 DataValue::Boolean(b) => *b,
395 DataValue::Integer(i) => *i != 0,
396 DataValue::Float(f) => *f != 0.0 && !f.is_nan(),
397 DataValue::Null => false,
398 _ => true,
399 }
400}